chore: setup workspace for exchange (#105)

Bradley Axen
2024-10-02 11:05:43 -07:00
committed by GitHub
parent aa70408ff7
commit 92fe8e7008
63 changed files with 5895 additions and 120 deletions

View File

@@ -5,7 +5,7 @@ on:
branches: [main]
jobs:
build:
exchange:
runs-on: ubuntu-latest
steps:
@@ -19,9 +19,82 @@ jobs:
- name: Ruff
run: |
uvx ruff check
uvx ruff format --check
uvx ruff check packages/exchange
uvx ruff format packages/exchange --check
- name: Run tests
working-directory: ./packages/exchange
run: |
uv run pytest tests -m 'not integration'
goose:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install UV
run: curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Source Cargo Environment
run: source $HOME/.cargo/env
- name: Ruff
run: |
uvx ruff check src tests
uvx ruff format src tests --check
- name: Run tests
run: |
uv run pytest tests -m 'not integration'
# This runs integration tests of the OpenAI API, using Ollama to host models.
# This lets us test PRs from forks which can't access secrets like API keys.
ollama:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
# Only test the latest Python version.
- "3.12"
ollama-model:
# For quicker CI, use a smaller, tool-capable model than the default.
- "qwen2.5:0.5b"
steps:
- uses: actions/checkout@v4
- name: Install UV
run: curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Source Cargo Environment
run: source $HOME/.cargo/env
- name: Set up Python
run: uv python install ${{ matrix.python-version }}
- name: Install Ollama
run: curl -fsSL https://ollama.com/install.sh | sh
- name: Start Ollama
run: |
# Run in the background, in a way that survives to the next step
nohup ollama serve > ollama.log 2>&1 &
# Block using the ready endpoint
time curl --retry 5 --retry-connrefused --retry-delay 1 -sf http://localhost:11434
# Tests use the OpenAI-compatible API, which does not have a mechanism to pull models. Run a
# simple prompt to (pull and) test the model first.
- name: Test Ollama model
run: ollama run $OLLAMA_MODEL hello || cat ollama.log
env:
OLLAMA_MODEL: ${{ matrix.ollama-model }}
- name: Run Ollama tests
run: uv run pytest tests -m integration -k ollama
working-directory: ./packages/exchange
env:
OLLAMA_MODEL: ${{ matrix.ollama-model }}
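The integration tests this job selects live under `packages/exchange/tests`. Purely as a hedged sketch (the test name, provider defaults, and environment handling below are assumptions, not the repo's actual test), a check of this kind could look like:

```python
import os

import pytest

from exchange import Exchange, Message
from exchange.providers import get_provider


@pytest.mark.integration
def test_ollama_hello():
    # Assumes OllamaProvider.from_env() points at the locally running
    # `ollama serve` started above, and that OLLAMA_MODEL is set by CI.
    provider = get_provider("ollama").from_env()
    ex = Exchange(
        provider=provider,
        model=os.environ.get("OLLAMA_MODEL", "qwen2.5:0.5b"),
        system="You are a helpful assistant.",
    )
    ex.add(Message.user("Say hello in one word."))
    assert ex.reply().text
```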

.github/workflows/publish.yaml (new file)
View File

@@ -0,0 +1,50 @@
name: Publish
# A release on goose will also publish exchange, if it has been updated
# This means in some cases we may need to bump goose without other changes in order to release exchange
on:
release:
types: [published]
jobs:
publish:
permissions:
id-token: write
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Get current version from pyproject.toml
id: get_version
run: |
echo "VERSION=$(grep -m 1 'version =' "pyproject.toml" | awk -F'"' '{print $2}')" >> $GITHUB_ENV
- name: Extract tag version
id: extract_tag
run: |
TAG_VERSION=$(echo "${{ github.event.release.tag_name }}" | sed -E 's/v(.*)/\1/')
echo "TAG_VERSION=$TAG_VERSION" >> $GITHUB_ENV
- name: Check if tag matches version from pyproject.toml
id: check_tag
run: |
if [ "${{ env.TAG_VERSION }}" != "${{ env.VERSION }}" ]; then
echo "::error::Tag version (${{ env.TAG_VERSION }}) does not match version in pyproject.toml (${{ env.VERSION }})."
exit 1
fi
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v1
with:
version: "latest"
- name: Build Package
run: |
uv build -o dist --package goose-ai
uv build -o dist --package ai-exchange
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
skip-existing: true

View File

@@ -1,47 +0,0 @@
name: PYPI Release
on:
push:
tags:
- 'v*'
jobs:
pypi_release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install UV
run: curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Source Cargo Environment
run: source $HOME/.cargo/env
- name: Build with UV
run: uvx --from build pyproject-build --installer uv
- name: Check version
id: check_version
run: |
PACKAGE_NAME=$(grep '^name =' pyproject.toml | sed -E 's/name = "(.*)"/\1/')
TAG_VERSION=$(echo "$GITHUB_REF" | sed -E 's/refs\/tags\/v(.+)/\1/')
CURRENT_VERSION=$(curl -s https://pypi.org/pypi/$PACKAGE_NAME/json | jq -r .info.version)
PROJECT_VERSION=$(grep '^version =' pyproject.toml | sed -E 's/version = "(.*)"/\1/')
if [ "$TAG_VERSION" != "$PROJECT_VERSION" ]; then
echo "Tag version does not match version in pyproject.toml"
exit 1
fi
if python -c "from packaging.version import parse as parse_version; exit(0 if parse_version('$TAG_VERSION') > parse_version('$CURRENT_VERSION') else 1)"; then
echo "new_version=true" >> $GITHUB_OUTPUT
else
exit 1
fi
- name: Publish
uses: pypa/gh-action-pypi-publish@v1.4.2
if: steps.check_version.outputs.new_version == 'true'
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN_TEMP }}
packages_dir: ./dist/

View File

@@ -1,44 +1,43 @@
# Contributing
<p>
<a href="#prerequisites">Prerequisites</a>
<a href="#developing-and-testing">Developing and testing</a>
<a href="#building-from-source">Building from source</a>
<a href="#developing-goose-plugins">Developing goose-plugins</a>
<a href="#running-ai-exchange-from-source">Running ai-exchange from source</a>
<a href="#evaluations">Evaluations</a>
<a href="#conventional-commits">Conventional Commits</a>
</p>
We welcome Pull Requests for general contributions. If you have a larger new feature or any questions on how to develop a fix, we recommend you open an [issue][issues] before starting.
## Prerequisites
Goose uses [uv][uv] for dependency management, and formats with [ruff][ruff].
Clone goose and make sure you have installed `uv` to get started. When you use
`uv` below in your local goose directory, it will automatically set up the virtualenv
and install dependencies.
We provide a shortcut to standard commands using [just][just] in our `justfile`.
Goose uses [uv][uv] for dependency management, and formats with [ruff][ruff] - install UV first: https://pypi.org/project/uv/
## Development
## Developing and testing
Now that you have a local environment, you can make edits and run our tests!
Now that you have a local environment, you can make edits and run our tests.
### Run Goose
### Creating plugins
If you've made edits and want to try them out, use
Goose is highly configurable through plugins - it reads in modules that its dependencies install (e.g. `goose-plugins`) and uses those that start with certain prefixes (e.g. `goose.toolkit`) to inject their functionality. For example, you will note that Goose's CLI is actually merged with additional CLI methods that are exported from `goose-plugins`.
```
uv run goose session start
```
If you are building a net new feature, you should try to fit it inside a plugin. Goose and `goose-plugins` both support plugins, but there's an important difference in how contributions to each are reviewed. Use the guidelines below to decide where to contribute:
or other `goose` commands.
**When to Add to Goose**:
If you want to run your local changes but in another directory, you can use the path in
the virtualenv created by uv:
Plugins added directly to Goose are subject to rigorous review. This is because they are part of the core system and need to meet higher standards for stability, performance, and maintainability, often being validated through benchmarking.
```
alias goosedev=`uv run which goose`
```
**When to Add to `goose-plugins`:**
You can then run `goosedev` from another dir and it will use your current changes.
Plugins in `goose-plugins` undergo less detailed reviews and are more modular or experimental. They can prove their value through usage or iteration over time and may be eventually moved over to Goose.
### Run Tests
To see how to add a toolkit, see the [toolkits documentation][adding-toolkit].
To run the test suite against your edits, use `pytest`:
### Running tests
```sh
uv run pytest tests -m "not integration"
```
@@ -49,58 +48,17 @@ or, as a shortcut,
just test
```
## Building from source
## Exchange
If you want to develop features on `goose`:
The lower-level generation behind goose is powered by the [`exchange`][ai-exchange] package, also in this repo.
1. Clone Goose:
```bash
git clone git@github.com:block-open-source/goose.git ~/Development/goose
```
2. Get `uv` with `brew install uv`
3. Set up your Python virtualenv:
```bash
cd ~/Development/goose
uv sync
uv venv
```
4. Run the `source` command that follows the `uv venv` command to activate the virtualenv.
5. Run Goose:
```bash
uv run goose session start # or any of goose's commands (e.g. goose --help)
```
### Running from source
When you build from source you may want to run it from elsewhere.
1. Run `uv sync` as above
2. Run ```export goose_dev=`uv run which goose` ```
3. You can then use that from anywhere on your system, for example `cd ~/ && $goose_dev session start`, or put it on your path (advanced users only) to always run the latest.
## Developing goose-plugins
1. Clone the `goose-plugins` repo:
```bash
git clone git@github.com:block-open-source/goose-plugins.git ~/Development/goose-plugins
```
2. Follow the steps for creating a virtualenv in the `goose` section above
3. Install `goose-plugins` in `goose`. This means any changes to `goose-plugins` in this folder will immediately be reflected in `goose`:
```bash
uv add --editable ~/Development/goose-plugins
```
4. Make your changes in `goose-plugins`, add the toolkit to the `profiles.yaml` file and run `uv run goose session --start` to see them in action.
## Running ai-exchange from source
goose depends heavily on the [`ai-exchange`][ai-exchange] project, you can clone that repo and then work on both by running:
Thanks to `uv` workspaces, any changes you make to `exchange` will be reflected when you use your local goose. To run tests
for exchange, head to `packages/exchange` and run them just like above:
```sh
uv add --editable <path/to/cloned/exchange>
uv run pytest tests -m "not integration"
```
then when you run goose with `uv run goose session start` it will be running it all from source.
## Evaluations
Given that so much of Goose involves interactions with LLMs, our unit tests only go so far in confirming that things work as intended.

View File

@@ -0,0 +1,94 @@
<p align="center">
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg"></a>
</p>
<p align="center">
<a href="#example">Example</a>
<a href="#plugins">Plugins</a>
</p>
<p align="center"><strong>Exchange</strong> <em>- a uniform python SDK for message generation with LLMs</em></p>
- Provides a flexible layer for message handling and generation
- Directly integrates python functions into tool calling
- Persistently surfaces errors to the underlying models to support reflection
## Example
> [!NOTE]
> Before you can run this example, you need to set up an API key with
> `export OPENAI_API_KEY=your-key-here`
``` python
from exchange import Exchange, Message, Tool
from exchange.providers import OpenAiProvider
def word_count(text: str):
"""Get the count of words in text
Args:
text (str): The text with words to count
"""
return len(text.split(" "))
ex = Exchange(
provider=OpenAiProvider.from_env(),
model="gpt-4o",
system="You are a helpful assistant.",
tools=[Tool.from_function(word_count)],
)
ex.add(Message.user("Count the number of words in this current message"))
# The model sees it has a word count tool, and should use it along the way to answer
# This will call all the tools as needed until the model replies with the final result
reply = ex.reply()
print(reply.text)
# you can see all the tool calls in the message history
print(ex.messages)
```
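The third bullet above (persistently surfacing errors back to the model) can be exercised with a tool that always fails. This is a minimal hedged sketch, assuming the same `OPENAI_API_KEY` setup as the example; the failing tool is purely illustrative and not part of the library:

```python
from exchange import Exchange, Message, Tool
from exchange.providers import OpenAiProvider


def fail(text: str):
    """A deliberately broken tool, to show error reflection

    Args:
        text (str): ignored
    """
    raise ValueError("boom")


ex = Exchange(
    provider=OpenAiProvider.from_env(),
    model="gpt-4o",
    system="You are a helpful assistant.",
    tools=[Tool.from_function(fail)],
)
ex.add(Message.user("Please call the fail tool once."))
reply = ex.reply()
# The traceback is sent back to the model as a ToolResult with is_error=True,
# so the reply can acknowledge the failure rather than crashing the exchange.
print(reply.text)
```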
## Plugins
*exchange* has a plugin mechanism to add support for additional providers and moderators. If you need a
provider not supported here, we'd be happy to review contributions. But you
can also consider building and using your own plugin.
To create a `Provider` plugin, subclass `exchange.provider.Provider`. You will need to
implement the `complete` method. For example, this is what we use as a mock in our tests.
You can see a full implementation example in the [OpenAiProvider][openaiprovider]. We
also generally recommend implementing a `from_env` classmethod to instantiate the provider.
``` python
class MockProvider(Provider):
def __init__(self, sequence: List[Message]):
# We'll use init to provide a preplanned reply sequence
self.sequence = sequence
self.call_count = 0
def complete(
self, model: str, system: str, messages: List[Message], tools: List[Tool]
) -> Message:
output = self.sequence[self.call_count]
self.call_count += 1
return output
```
Then use [python packaging's entrypoints][plugins] to register your plugin.
``` toml
[project.entry-points.'exchange.provider']
example = 'path.to.plugin:ExampleProvider'
```
Your plugin will then be available in your application or other applications built on *exchange*
through:
``` python
from exchange.providers import get_provider
provider = get_provider('example').from_env()
```
[openaiprovider]: src/exchange/providers/openai.py
[plugins]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/

View File

@@ -0,0 +1,48 @@
[project]
name = "ai-exchange"
version = "0.9.3"
description = "a uniform python SDK for message generation with LLMs"
readme = "README.md"
requires-python = ">=3.10"
authors = [{ name = "Block", email = "ai-oss-tools@block.xyz" }]
packages = [{ include = "exchange", from = "src" }]
dependencies = [
"griffe>=1.1.1",
"attrs>=24.2.0",
"jinja2>=3.1.4",
"tiktoken>=0.7.0",
"httpx>=0.27.0",
"tenacity>=9.0.0",
]
[tool.hatch.build.targets.wheel]
packages = ["src/exchange"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
dev-dependencies = ["pytest>=8.3.2", "pytest-vcr>=1.0.2", "codecov>=2.1.13"]
[project.entry-points."exchange.provider"]
openai = "exchange.providers.openai:OpenAiProvider"
azure = "exchange.providers.azure:AzureProvider"
databricks = "exchange.providers.databricks:DatabricksProvider"
anthropic = "exchange.providers.anthropic:AnthropicProvider"
bedrock = "exchange.providers.bedrock:BedrockProvider"
ollama = "exchange.providers.ollama:OllamaProvider"
google = "exchange.providers.google:GoogleProvider"
[project.entry-points."exchange.moderator"]
passive = "exchange.moderators.passive:PassiveModerator"
truncate = "exchange.moderators.truncate:ContextTruncate"
summarize = "exchange.moderators.summarizer:ContextSummarizer"
[project.entry-points."metadata.plugins"]
ai-exchange = "exchange:module_name"
[tool.pytest.ini_options]
markers = [
"integration: marks tests that need to authenticate (deselect with '-m \"not integration\"')",
]

View File

@@ -0,0 +1,9 @@
"""Classes for interacting with the exchange API."""
from exchange.tool import Tool # noqa
from exchange.content import Text, ToolResult, ToolUse # noqa
from exchange.message import Message # noqa
from exchange.exchange import Exchange # noqa
from exchange.checkpoint import CheckpointData, Checkpoint # noqa
module_name = "ai-exchange"

View File

@@ -0,0 +1,67 @@
from copy import deepcopy
from typing import List
from attrs import define, field
@define
class Checkpoint:
"""Checkpoint that counts the tokens in messages between the start and end index"""
start_index: int = field(default=0) # inclusive
end_index: int = field(default=0) # inclusive
token_count: int = field(default=0)
def __deepcopy__(self, _) -> "Checkpoint": # noqa: ANN001
"""
Returns a deep copy of the Checkpoint object.
"""
return Checkpoint(
start_index=self.start_index,
end_index=self.end_index,
token_count=self.token_count,
)
@define
class CheckpointData:
"""Aggregates all information about checkpoints"""
# the total number of tokens in the exchange. this is updated every time a checkpoint is
# added or removed
total_token_count: int = field(default=0)
# in-order list of individual checkpoints in the exchange
checkpoints: List[Checkpoint] = field(factory=list)
# the offset to apply to the message index when calculating the last message index
# this is useful because messages on the exchange behave like a queue, where you can only
# pop from the left or right sides. This offset allows us to map the checkpoint indices
# to the correct message index, even if we have popped messages from the left side of
# the exchange in the past. we reset this offset to 0 when we empty the checkpoint data.
message_index_offset: int = field(default=0)
def __deepcopy__(self, memo: dict) -> "CheckpointData":
"""Returns a deep copy of the CheckpointData object."""
return CheckpointData(
total_token_count=self.total_token_count,
checkpoints=deepcopy(self.checkpoints, memo),
message_index_offset=self.message_index_offset,
)
@property
def last_message_index(self) -> int:
if not self.checkpoints:
return -1 # we don't have enough information to know
return self.checkpoints[-1].end_index - self.message_index_offset
def reset(self) -> None:
"""Resets the checkpoint data to its initial state."""
self.checkpoints = []
self.message_index_offset = 0
self.total_token_count = 0
def pop(self, index: int = -1) -> Checkpoint:
"""Removes and returns the checkpoint at the given index."""
popped_checkpoint = self.checkpoints.pop(index)
self.total_token_count = self.total_token_count - popped_checkpoint.token_count
return popped_checkpoint
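A small hedged sketch of how this bookkeeping behaves (the indices and token counts below are made up):

```python
from exchange.checkpoint import Checkpoint, CheckpointData

data = CheckpointData()
data.checkpoints.append(Checkpoint(start_index=0, end_index=1, token_count=12))
data.total_token_count = 12

assert data.last_message_index == 1   # end_index minus message_index_offset
data.pop()                            # removes the checkpoint and its token count
assert data.total_token_count == 0
```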

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict, Optional
from attrs import define, asdict
CONTENT_TYPES = {}
class Content:
def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None:
super().__init_subclass__(**kwargs)
CONTENT_TYPES[cls.__name__] = cls
def to_dict(self) -> Dict[str, Any]:
data = asdict(self, recurse=True)
data["type"] = self.__class__.__name__
return data
@define
class Text(Content):
text: str
@define
class ToolUse(Content):
id: str
name: str
parameters: Any
is_error: bool = False
error_message: Optional[str] = None
@define
class ToolResult(Content):
tool_use_id: str
output: str
is_error: bool = False
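A quick hedged sketch of the serialization round-trip these classes enable (this is what `content_converter` in message.py relies on):

```python
from exchange.content import CONTENT_TYPES, Text

data = Text(text="hi").to_dict()
assert data == {"text": "hi", "type": "Text"}

# The recorded "type" key is what lets content be rebuilt from plain dicts.
rebuilt = CONTENT_TYPES[data.pop("type")](**data)
assert isinstance(rebuilt, Text) and rebuilt.text == "hi"
```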

View File

@@ -0,0 +1,336 @@
import json
import traceback
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Tuple
from attrs import define, evolve, field, Factory
from tiktoken import get_encoding
from exchange.checkpoint import Checkpoint, CheckpointData
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.moderators import Moderator
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
from exchange.tool import Tool
from exchange.token_usage_collector import _token_usage_collector
def validate_tool_output(output: str) -> None:
"""Validate tool output for the given model"""
max_output_chars = 2**20
max_output_tokens = 16000
encoder = get_encoding("cl100k_base")
if len(output) > max_output_chars or len(encoder.encode(output)) > max_output_tokens:
raise ValueError("This tool call created an output that was too long to handle!")
@define(frozen=True)
class Exchange:
"""An exchange of messages with an LLM
The exchange class is meant to be largely immutable, with only the message list
growing once constructed. Use .replace to alter the model, tools, etc.
The exchange supports tool usage, calling tools and letting the model respond when
using the .reply method. It handles most forms of errors and sends those errors back
to the model, to let it attempt to recover.
"""
provider: Provider
model: str
system: str
moderator: Moderator = field(default=ContextTruncate())
tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
messages: List[Message] = field(factory=list)
checkpoint_data: CheckpointData = field(factory=CheckpointData)
generation_args: dict = field(default=Factory(dict))
@property
def _toolmap(self) -> Mapping[str, Tool]:
return {tool.name: tool for tool in self.tools}
def replace(self, **kwargs: Dict[str, Any]) -> "Exchange":
"""Make a copy of the exchange, replacing any passed arguments"""
# TODO: ensure that the checkpoint data is updated correctly. aka,
# if we replace the messages, we need to update the checkpoint data
# if we change the model, we need to update the checkpoint data (?)
if kwargs.get("messages") is None:
kwargs["messages"] = deepcopy(self.messages)
if kwargs.get("checkpoint_data") is None:
kwargs["checkpoint_data"] = deepcopy(
self.checkpoint_data,
)
return evolve(self, **kwargs)
def add(self, message: Message) -> None:
"""Add a message to the history."""
if self.messages and message.role == self.messages[-1].role:
raise ValueError("Messages in the exchange must alternate between user and assistant")
self.messages.append(message)
def generate(self) -> Message:
"""Generate the next message."""
self.moderator.rewrite(self)
message, usage = self.provider.complete(
self.model,
self.system,
messages=self.messages,
tools=self.tools,
**self.generation_args,
)
self.add(message)
self.add_checkpoints_from_usage(usage) # this has to come after adding the response
# TODO: also call `rewrite` here, as this will make our
# messages *consistently* below the token limit. this currently
# is not the case because we could append a large message after calling
# `rewrite` above.
# self.moderator.rewrite(self)
_token_usage_collector.collect(self.model, usage)
return message
def reply(self, max_tool_use: int = 128) -> Message:
"""Get the reply from the underlying model.
This will process any requests for tool calls, calling them immediately, and
storing the intermediate tool messages in the queue. It will return after the
first response that does not request a tool use
Args:
max_tool_use: The maximum number of tool calls to make before returning. Defaults to 128.
"""
if max_tool_use <= 0:
raise ValueError("max_tool_use must be greater than 0")
response = self.generate()
curr_iter = 1 # generate() already called once
while response.tool_use:
content = []
for tool_use in response.tool_use:
tool_result = self.call_function(tool_use)
content.append(tool_result)
self.add(Message(role="user", content=content))
# We've reached the limit of tool calls - break out of the loop
if curr_iter >= max_tool_use:
# At this point, the most recent message is `Message(role='user', content=ToolResult(...))`
response = Message.assistant(
f"We've stopped executing additional tool cause because we reached the limit of {max_tool_use}",
)
self.add(response)
break
else:
response = self.generate()
curr_iter += 1
return response
def call_function(self, tool_use: ToolUse) -> ToolResult:
"""Call the function indicated by the tool use"""
tool = self._toolmap.get(tool_use.name)
if tool is None or tool_use.is_error:
output = f"ERROR: Failed to use tool {tool_use.id}.\nDo NOT use the same tool name and parameters again - that will lead to the same error." # noqa: E501
if tool_use.is_error:
output += f"\n{tool_use.error_message}"
elif tool is None:
valid_tool_names = ", ".join(self._toolmap.keys())
output += f"\nNo tool exists with the name '{tool_use.name}'. Valid tool names are: {valid_tool_names}"
return ToolResult(tool_use_id=tool_use.id, output=output, is_error=True)
try:
if isinstance(tool_use.parameters, dict):
output = json.dumps(tool.function(**tool_use.parameters))
elif isinstance(tool_use.parameters, list):
output = json.dumps(tool.function(*tool_use.parameters))
else:
raise ValueError(
f"The provided tool parameters, {tool_use.parameters} could not be interpreted as a mapping of arguments." # noqa: E501
)
validate_tool_output(output)
is_error = False
except Exception as e:
tb = traceback.format_exc()
output = str(tb) + "\n" + str(e)
is_error = True
return ToolResult(tool_use_id=tool_use.id, output=output, is_error=is_error)
def add_tool_use(self, tool_use: ToolUse) -> None:
"""Manually add a tool use and corresponding result
This will call the implied function and add an assistant
message requesting the ToolUse and a user message with the ToolResult
"""
tool_result = self.call_function(tool_use)
self.add(Message(role="assistant", content=[tool_use]))
self.add(Message(role="user", content=[tool_result]))
def add_checkpoints_from_usage(self, usage: Usage) -> None:
"""
Add checkpoints to the exchange based on the token counts of the last two
groups of messages, as well as the current token total count of the exchange
"""
# we know we just appended one message as the response from the LLM
# so we need to create two checkpoints as we know the token counts
# of the last two groups of messages:
# 1. from the last checkpoint to the most recent user message
# 2. the most recent assistant message
last_checkpoint_end_index = (
self.checkpoint_data.checkpoints[-1].end_index - self.checkpoint_data.message_index_offset
if len(self.checkpoint_data.checkpoints) > 0
else -1
)
new_start_index = last_checkpoint_end_index + 1
# here, our self.checkpoint_data.total_token_count is the previous total token count from the last time
# that we performed a request. if we subtract this value from the input_tokens from our
# latest response, we know how many tokens our **1** from above is.
first_block_token_count = usage.input_tokens - self.checkpoint_data.total_token_count
second_block_token_count = usage.output_tokens
if len(self.messages) - new_start_index > 1:
# this will occur most of the time, as we will have one new user message and one
# new assistant message.
self.checkpoint_data.checkpoints.append(
Checkpoint(
start_index=new_start_index + self.checkpoint_data.message_index_offset,
# end index below is equivalent to the second-to-last message. why? because
# the last message is the assistant message that we add below. we need to also
# track the token count of the user message sent.
end_index=len(self.messages) - 2 + self.checkpoint_data.message_index_offset,
token_count=first_block_token_count,
)
)
self.checkpoint_data.checkpoints.append(
Checkpoint(
start_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset,
end_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset,
token_count=second_block_token_count,
)
)
# TODO: check if the front of the checkpoints doesn't overlap with
# the first message. if so, we are missing checkpoint data from
# message[0] to message[checkpoint_data.checkpoints[0].start_index]
# we can fill in this data by performing an extra request and doing some math
self.checkpoint_data.total_token_count = usage.total_tokens
def pop_last_message(self) -> Message:
"""Pop the last message from the exchange, handling checkpoints correctly"""
if (
len(self.checkpoint_data.checkpoints) > 0
and self.checkpoint_data.last_message_index > len(self.messages) - 1
):
raise ValueError("Our checkpoint data is out of sync with our message data")
if (
len(self.checkpoint_data.checkpoints) > 0
and self.checkpoint_data.last_message_index == len(self.messages) - 1
):
# remove the last checkpoint, because we no longer know the token count of its contents.
# note that this is not the same as reverting to the last checkpoint, as we want to
# keep the messages from the last checkpoint. they will have a new checkpoint created for
# them when we call generate() again
self.checkpoint_data.pop()
self.messages.pop()
def pop_first_message(self) -> Message:
"""Pop the first message from the exchange, handling checkpoints correctly"""
if len(self.messages) == 0:
raise ValueError("There are no messages to pop")
if len(self.checkpoint_data.checkpoints) == 0:
raise ValueError("There must be at least one checkpoint to pop the first message")
# get the start and end indexes of the first checkpoint, use these to remove message
first_checkpoint = self.checkpoint_data.checkpoints[0]
first_checkpoint_start_index = first_checkpoint.start_index - self.checkpoint_data.message_index_offset
# check if the first message is part of the first checkpoint
if first_checkpoint_start_index == 0:
# remove this checkpoint, as it no longer has any messages
self.checkpoint_data.pop(0)
self.messages.pop(0)
self.checkpoint_data.message_index_offset += 1
if len(self.checkpoint_data.checkpoints) == 0:
# we've removed all the checkpoints, so we need to reset the message index offset
self.checkpoint_data.message_index_offset = 0
def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
"""
Reverts the exchange back to the last checkpoint, removing associated messages
"""
removed_checkpoint = self.checkpoint_data.checkpoints.pop()
# pop messages until we reach the start of the next checkpoint
messages = []
while len(self.messages) > removed_checkpoint.start_index - self.checkpoint_data.message_index_offset:
messages.append(self.messages.pop())
return removed_checkpoint, messages
def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
"""
Pop the first checkpoint from the exchange, removing associated messages
"""
if len(self.checkpoint_data.checkpoints) == 0:
raise ValueError("There are no checkpoints to pop")
first_checkpoint = self.checkpoint_data.pop(0)
# remove messages until we reach the start of the next checkpoint
messages = []
stop_at_index = first_checkpoint.end_index - self.checkpoint_data.message_index_offset
for _ in range(stop_at_index + 1): # +1 because it's inclusive
messages.append(self.messages.pop(0))
self.checkpoint_data.message_index_offset += 1
if len(self.checkpoint_data.checkpoints) == 0:
# we've removed all the checkpoints, so we need to reset the message index offset
self.checkpoint_data.message_index_offset = 0
return first_checkpoint, messages
def prepend_checkpointed_message(self, message: Message, token_count: int) -> None:
"""Prepend a message to the exchange, updating the checkpoint data"""
self.messages.insert(0, message)
new_index = max(0, self.checkpoint_data.message_index_offset - 1)
self.checkpoint_data.checkpoints.insert(
0,
Checkpoint(
start_index=new_index,
end_index=new_index,
token_count=token_count,
),
)
self.checkpoint_data.message_index_offset = new_index
def rewind(self) -> None:
if not self.messages:
return
# we remove messages until we find the last user text message
while not (self.messages[-1].role == "user" and type(self.messages[-1].content[-1]) is Text):
self.pop_last_message()
# now we remove that last user text message, putting us at a good point
# to ask the user for their input again
if self.messages:
self.pop_last_message()
@property
def is_allowed_to_call_llm(self) -> bool:
"""
Returns True if the exchange is allowed to call the LLM, False otherwise
"""
# TODO: reconsider whether this function belongs here and whether it is necessary
# Some models will have different requirements than others, so it may be better for
# this to be a required method of the provider instead.
return len(self.messages) > 0 and self.messages[-1].role == "user"
def get_token_usage(self) -> Dict[str, Usage]:
return _token_usage_collector.get_token_usage_group_by_model()
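As a hedged illustration of `add_tool_use`, reusing the `word_count` tool from the README (the id and arguments are invented for the example), manually recording a call appends both the assistant-side ToolUse and the user-side ToolResult:

```python
from exchange import Exchange, Message, Tool
from exchange.content import ToolUse
from exchange.providers import OpenAiProvider


def word_count(text: str):
    """Get the count of words in text

    Args:
        text (str): The text with words to count
    """
    return len(text.split(" "))


ex = Exchange(
    provider=OpenAiProvider.from_env(),  # only needs OPENAI_API_KEY to construct
    model="gpt-4o",
    system="You are a helpful assistant.",
    tools=[Tool.from_function(word_count)],
)

# Record both sides of a call without asking the model for it:
ex.add_tool_use(ToolUse(id="toolu_01", name="word_count", parameters={"text": "a b c"}))
# ex.messages now ends with an assistant message carrying the ToolUse and a
# user message carrying the ToolResult whose output is "3".
```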

View File

@@ -0,0 +1,121 @@
import inspect
import time
from pathlib import Path
from typing import Any, Dict, List, Literal, Type
from attrs import define, field
from jinja2 import Environment, FileSystemLoader
from exchange.content import CONTENT_TYPES, Content, Text, ToolResult, ToolUse
from exchange.utils import create_object_id
Role = Literal["user", "assistant"]
def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401
if instance.role == "user":
if not (instance.text or instance.tool_result):
raise ValueError("User message must include a Text or ToolResult")
if instance.tool_use:
raise ValueError("User message does not support ToolUse")
elif instance.role == "assistant":
if not (instance.text or instance.tool_use):
raise ValueError("Assistant message must include a Text or ToolUsage")
if instance.tool_result:
raise ValueError("Assistant message does not support ToolResult")
def content_converter(contents: List[Dict[str, Any]]) -> List[Content]:
return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents]
@define
class Message:
"""A message to or from a language model.
This supports several content types to extend to tool usage and (tbi) images.
We also provide shortcuts for simplified text usage; these two are identical:
```
m = Message(role='user', content=[Text(text='abcd')])
assert m.content[0].text == 'abcd'
m = Message.user('abcd')
assert m.text == 'abcd'
```
"""
role: Role = field(default="user")
id: str = field(factory=lambda: str(create_object_id(prefix="msg")))
created: int = field(factory=lambda: int(time.time()))
content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"id": self.id,
"created": self.created,
"content": [item.to_dict() for item in self.content],
}
@property
def text(self) -> str:
"""The text content of this message."""
result = []
for content in self.content:
if isinstance(content, Text):
result.append(content.text)
return "\n".join(result)
@property
def tool_use(self) -> List[ToolUse]:
"""All tool use content of this message."""
result = []
for content in self.content:
if isinstance(content, ToolUse):
result.append(content)
return result
@property
def tool_result(self) -> List[ToolResult]:
"""All tool result content of this message."""
result = []
for content in self.content:
if isinstance(content, ToolResult):
result.append(content)
return result
@classmethod
def load(
cls: Type["Message"],
filename: str,
role: Role = "user",
**kwargs: Dict[str, Any],
) -> "Message":
"""Load the message from filename relative to where the load is called.
This only supports simplified content, with a single text entry
This is meant to emulate importing code rather than a runtime filesystem. So
if you have a directory of code that contains example.py, and example.py has
a function that calls Message.load('example.jinja'), it will look in the same
directory as example.py for the jinja file.
"""
frm = inspect.stack()[1]
mod = inspect.getmodule(frm[0])
base_path = Path(mod.__file__).parent
env = Environment(loader=FileSystemLoader(base_path))
template = env.get_template(filename)
rendered_content = template.render(**kwargs)
return cls(role=role, content=[Text(text=rendered_content)])
@classmethod
def user(cls: Type["Message"], text: str) -> "Message":
return cls(role="user", content=[Text(text)])
@classmethod
def assistant(cls: Type["Message"], text: str) -> "Message":
return cls(role="assistant", content=[Text(text)])

View File

@@ -0,0 +1,13 @@
from functools import cache
from typing import Type
from exchange.moderators.base import Moderator
from exchange.utils import load_plugins
from exchange.moderators.passive import PassiveModerator # noqa
from exchange.moderators.truncate import ContextTruncate # noqa
from exchange.moderators.summarizer import ContextSummarizer # noqa
@cache
def get_moderator(name: str) -> Type[Moderator]:
return load_plugins(group="exchange.moderator")[name]

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Type
class Moderator(ABC):
@abstractmethod
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass

View File

@@ -0,0 +1,7 @@
from typing import Type
from exchange.moderators.base import Moderator
class PassiveModerator(Moderator):
def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass

View File

@@ -0,0 +1,9 @@
You are an expert technical summarizer.
During your conversation with the user, you may be asked to summarize the content in your conversational history.
When asked to summarize, you should concisely summarize the conversation giving emphasis to newer content. Newer content will be towards the end of the conversation.
Preferentially keep user supplied content in the summary.
The summary *MUST* include filenames that were touched and/or modified. If the updates occurred more recently, keep the latest modifications made to the files in the summary. If the changes occurred earlier in the chat, briefly summarize the changes and don't include the changes in the summary.
There will likely be json formatted blocks referencing ToolUse and ToolResults. You can ignore ToolUse references, but keep the ToolResult outputs, summarizing as needed and with the same guidelines as above.

View File

@@ -0,0 +1,46 @@
from typing import Type
from exchange import Message
from exchange.checkpoint import CheckpointData
from exchange.moderators import ContextTruncate, PassiveModerator
class ContextSummarizer(ContextTruncate):
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
"""Summarize the context history up to the last few messages in the exchange"""
self._update_system_prompt_token_count(exchange)
# TODO: use an offset for summarization
if exchange.checkpoint_data.total_token_count < self.max_tokens:
return
messages_to_summarize = self._get_messages_to_remove(exchange)
num_messages_to_remove = len(messages_to_summarize)
# the llm will throw an error if the last message isn't a user message
if messages_to_summarize[-1].role == "assistant" and (not messages_to_summarize[-1].tool_use):
messages_to_summarize.append(Message.user("Summarize the above conversation"))
summarizer_exchange = exchange.replace(
system=Message.load("summarizer.jinja").text,
moderator=PassiveModerator(),
model=self.model,
messages=messages_to_summarize,
checkpoint_data=CheckpointData(),
)
# get the summarized content and the tokens associated with this content
summary = summarizer_exchange.reply()
summary_checkpoint = summarizer_exchange.checkpoint_data.checkpoints[-1]
# remove the checkpoints that were summarized from the original exchange
for _ in range(num_messages_to_remove):
exchange.pop_first_message()
# insert summary as first message/checkpoint
if len(exchange.messages) == 0 or exchange.messages[0].role == "assistant":
summary_message = Message.user(summary.text)
else:
summary_message = Message.assistant(summary.text)
exchange.prepend_checkpointed_message(summary_message, summary_checkpoint.token_count)

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from exchange.checkpoint import CheckpointData
from exchange.message import Message
from exchange.moderators import PassiveModerator
from exchange.moderators.base import Moderator
if TYPE_CHECKING:
from exchange.exchange import Exchange
# currently this is the point at which we start to truncate, so once we get to
# this token size the token count will exceed this limit by a little bit.
# TODO: make this configurable for each provider
MAX_TOKENS = 100000
class ContextTruncate(Moderator):
def __init__(
self,
model: Optional[str] = None,
max_tokens: int = MAX_TOKENS,
) -> None:
self.model = model
self.system_prompt_token_count = 0
self.max_tokens = max_tokens
self.last_system_prompt = None
def rewrite(self, exchange: Exchange) -> None:
"""Truncate the exchange messages with a FIFO strategy."""
self._update_system_prompt_token_count(exchange)
if exchange.checkpoint_data.total_token_count < self.max_tokens:
return
messages_to_remove = self._get_messages_to_remove(exchange)
for _ in range(len(messages_to_remove)):
exchange.pop_first_message()
def _update_system_prompt_token_count(self, exchange: Exchange) -> None:
is_different_system_prompt = False
if self.last_system_prompt != exchange.system:
is_different_system_prompt = True
self.last_system_prompt = exchange.system
if not self.system_prompt_token_count or is_different_system_prompt:
# calculate the system prompt tokens (includes functions etc...)
# we use a placeholder message with one token, which we subtract later
# this ensures compatibility with providers that require a user message
_system_token_exchange = exchange.replace(
messages=[Message.user("a")],
checkpoint_data=CheckpointData(),
moderator=PassiveModerator(),
model=self.model if self.model else exchange.model,
)
_system_token_exchange.generate()
last_system_prompt_token_count = self.system_prompt_token_count
self.system_prompt_token_count = _system_token_exchange.checkpoint_data.total_token_count - 1
exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
exchange.checkpoint_data.total_token_count += self.system_prompt_token_count
def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
# this keeps all the messages/checkpoints
throwaway_exchange = exchange.replace(
moderator=PassiveModerator(),
)
# get the messages that we want to remove
messages_to_remove = []
while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens:
_, messages = throwaway_exchange.pop_first_checkpoint()
messages_to_remove.extend(messages)
while len(throwaway_exchange.messages) > 0 and throwaway_exchange.messages[0].tool_result:
# we would need a corresponding tool use once we resume, so we pop this one off too
# and summarize it as well
_, messages = throwaway_exchange.pop_first_checkpoint()
messages_to_remove.extend(messages)
return messages_to_remove
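Both moderators accept `model` and `max_tokens`, so the default truncation point can be overridden when constructing an Exchange. A hedged sketch (the model names here are assumptions, not recommendations):

```python
from exchange import Exchange
from exchange.moderators import ContextSummarizer
from exchange.providers import OpenAiProvider

ex = Exchange(
    provider=OpenAiProvider.from_env(),
    model="gpt-4o",
    system="You are a helpful assistant.",
    # Summarize (rather than drop) old checkpoints once we pass ~50k tokens,
    # using a cheaper model for the summaries.
    moderator=ContextSummarizer(model="gpt-4o-mini", max_tokens=50_000),
)
```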

View File

@@ -0,0 +1,17 @@
from functools import cache
from typing import Type
from exchange.providers.anthropic import AnthropicProvider # noqa
from exchange.providers.base import Provider, Usage # noqa
from exchange.providers.databricks import DatabricksProvider # noqa
from exchange.providers.openai import OpenAiProvider # noqa
from exchange.providers.ollama import OllamaProvider # noqa
from exchange.providers.azure import AzureProvider # noqa
from exchange.providers.google import GoogleProvider # noqa
from exchange.utils import load_plugins
@cache
def get_provider(name: str) -> Type[Provider]:
return load_plugins(group="exchange.provider")[name]

View File

@@ -0,0 +1,158 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)
class AnthropicProvider(Provider):
def __init__(self, client: httpx.Client) -> None:
self.client = client
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
try:
key = os.environ["ANTHROPIC_API_KEY"]
except KeyError:
raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment")
client = httpx.Client(
base_url=url,
headers={
"x-api-key": key,
"content-type": "application/json",
"anthropic-version": "2023-06-01",
},
timeout=httpx.Timeout(60 * 10),
)
return cls(client)
@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
usage = data.get("usage")
input_tokens = usage.get("input_tokens")
output_tokens = usage.get("output_tokens")
total_tokens = usage.get("total_tokens")
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
@staticmethod
def anthropic_response_to_message(response: Dict) -> Message:
content_blocks = response.get("content", [])
content = []
for block in content_blocks:
if block["type"] == "text":
content.append(Text(text=block["text"]))
elif block["type"] == "tool_use":
content.append(
ToolUse(
id=block["id"],
name=block["name"],
parameters=block["input"],
)
)
return Message(role="assistant", content=content)
@staticmethod
def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]:
return [
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.parameters,
}
for tool in tools
]
@staticmethod
def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]:
messages_spec = []
# if messages is empty - just make a default
for message in messages:
converted = {"role": message.role}
for content in message.content:
if isinstance(content, Text):
converted["content"] = [{"type": "text", "text": content.text}]
elif isinstance(content, ToolUse):
converted.setdefault("content", []).append(
{
"type": "tool_use",
"id": content.id,
"name": content.name,
"input": content.parameters,
}
)
elif isinstance(content, ToolResult):
converted.setdefault("content", []).append(
{
"type": "tool_result",
"tool_use_id": content.tool_use_id,
"content": content.output,
}
)
messages_spec.append(converted)
if len(messages_spec) == 0:
converted = {
"role": "user",
"content": [{"type": "text", "text": "Ignore"}],
}
messages_spec.append(converted)
return messages_spec
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
tools_set = set()
unique_tools = []
for tool in tools:
if tool.name not in tools_set:
unique_tools.append(tool)
tools_set.add(tool.name)
payload = dict(
system=system,
model=model,
max_tokens=4096,
messages=self.messages_to_anthropic_spec(messages),
tools=self.tools_to_anthropic_spec(tuple(unique_tools)),
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(payload)
message = self.anthropic_response_to_message(response)
usage = self.get_usage(response)
return message, usage
@retry_procedure
def _post(self, payload: dict) -> httpx.Response:
response = self.client.post(ANTHROPIC_HOST, json=payload)
return raise_for_status(response).json()

View File

@@ -0,0 +1,45 @@
import os
from typing import Type
import httpx
from exchange.providers import OpenAiProvider
class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service"""
def __init__(self, client: httpx.Client) -> None:
super().__init__(client)
@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
try:
url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.")
try:
deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.")
try:
api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.")
try:
key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"]
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")
# format the url as {host}/openai/deployments/{deployment_name}/ and pass api-version as a query param
url = f"{url}/openai/deployments/{deployment_name}/"
client = httpx.Client(
base_url=url,
headers={"api-key": key, "Content-Type": "application/json"},
params={"api-version": api_version},
timeout=httpx.Timeout(60 * 10),
)
return cls(client)
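For reference, a hedged sketch of the environment this provider expects; the values below are placeholders and the api-version shown is an assumption:

```python
import os

from exchange.providers import AzureProvider

os.environ.update(
    {
        "AZURE_CHAT_COMPLETIONS_HOST_NAME": "https://my-resource.openai.azure.com",
        "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "my-gpt-4o-deployment",
        "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "2024-02-01",
        "AZURE_CHAT_COMPLETIONS_KEY": "...",
    }
)
provider = AzureProvider.from_env()
# Requests are sent relative to {host}/openai/deployments/{deployment}/
# with the api-version query parameter attached by the client.
```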

View File

@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Tuple, Type
from exchange.message import Message
from exchange.tool import Tool
@define(hash=True)
class Usage:
input_tokens: int = field(default=None)
output_tokens: int = field(default=None)
total_tokens: int = field(default=None)
class Provider(ABC):
@classmethod
def from_env(cls: Type["Provider"]) -> "Provider":
return cls()
@abstractmethod
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
) -> Tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass

View File

@@ -0,0 +1,328 @@
import hashlib
import hmac
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import quote, urlparse
import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
from exchange.tool import Tool
SERVICE = "bedrock-runtime"
UTC = timezone.utc
logger = logging.getLogger(__name__)
retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)
class AwsClient(httpx.Client):
def __init__(
self,
aws_region: str,
aws_access_key: str,
aws_secret_key: str,
aws_session_token: Optional[str] = None,
**kwargs: Dict[str, Any],
) -> None:
self.region = aws_region
self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/"
self.access_key = aws_access_key
self.secret_key = aws_secret_key
self.session_token = aws_session_token
super().__init__(base_url=self.host, timeout=600, **kwargs)
def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response:
signed_headers = self.sign_and_get_headers(
method="POST",
url=path,
payload=json,
service="bedrock",
)
return super().post(url=path, json=json, headers=signed_headers, **kwargs)
def sign_and_get_headers(
self,
method: str,
url: str,
payload: dict,
service: str,
) -> Dict[str, str]:
"""
Sign the request and generate the necessary headers for AWS authentication.
Args:
method (str): HTTP method (e.g., 'GET', 'POST').
url (str): The request URL.
payload (dict): The request payload.
service (str): The AWS service name.
The region, credentials, and session token are taken from the client instance.
Returns:
Dict[str, str]: The headers required for the request.
"""
def sign(key: bytes, msg: str) -> bytes:
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: str) -> bytes:
k_date = sign(("AWS4" + key).encode("utf-8"), date_stamp)
k_region = sign(k_date, region_name)
k_service = sign(k_region, service_name)
k_signing = sign(k_service, "aws4_request")
return k_signing
# Convert payload to JSON string
request_parameters = json.dumps(payload)
# Create a date for headers and the credential string
t = datetime.now(UTC)
amz_date = t.strftime("%Y%m%dT%H%M%SZ")
date_stamp = t.strftime("%Y%m%d") # Date w/o time, used in credential scope
# Create canonical URI and headers
parsedurl = urlparse(url)
canonical_uri = quote(parsedurl.path if parsedurl.path else "/", safe="/-_.~")
canonical_headers = f"host:{parsedurl.netloc}\nx-amz-date:{amz_date}\n"
# Create the list of signed headers.
signed_headers = "host;x-amz-date"
if self.session_token:
canonical_headers += "x-amz-security-token:" + self.session_token + "\n"
signed_headers += ";x-amz-security-token"
# Create payload hash
payload_hash = hashlib.sha256(request_parameters.encode("utf-8")).hexdigest()
# Canonical request
canonical_request = f"{method}\n{canonical_uri}\n\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
# Create the string to sign
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}'
)
# Create the signing key
signing_key = get_signature_key(self.secret_key, date_stamp, self.region, service)
# Sign the string_to_sign using the signing key
signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
# Add signing information to the request
authorization_header = (
f"{algorithm} Credential={self.access_key}/{credential_scope}, SignedHeaders={signed_headers}, "
f"Signature={signature}"
)
# Headers
headers = {
"Content-Type": "application/json",
"Authorization": authorization_header,
"X-Amz-date": amz_date.encode(),
"x-amz-content-sha256": payload_hash,
}
if self.session_token:
headers["X-Amz-Security-Token"] = self.session_token
return headers
class BedrockProvider(Provider):
def __init__(self, client: AwsClient) -> None:
self.client = client
@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
aws_region = os.environ.get("AWS_REGION", "us-east-1")
try:
aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
aws_session_token = os.environ.get("AWS_SESSION_TOKEN")
except KeyError:
raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment")
client = AwsClient(
aws_region=aws_region,
aws_access_key=aws_access_key,
aws_secret_key=aws_secret_key,
aws_session_token=aws_session_token,
)
return cls(client=client)
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
"""
Generate a completion response from the Bedrock gateway.
Args:
model (str): The model identifier.
system (str): The system prompt or configuration.
messages (List[Message]): A list of messages to be processed by the model.
tools (Tuple[Tool]): A tuple of tools to be used in the completion process.
**kwargs: Additional keyword arguments for inference configuration.
Returns:
Tuple[Message, Usage]: A tuple containing the response message and usage data.
"""
inference_config = dict(
temperature=kwargs.pop("temperature", None),
maxTokens=kwargs.pop("max_tokens", None),
stopSequences=kwargs.pop("stop", None),
topP=kwargs.pop("topP", None),
)
inference_config = {k: v for k, v in inference_config.items() if v is not None} or None
converted_messages = [self.message_to_bedrock_spec(message) for message in messages]
converted_system = [dict(text=system)]
tool_config = self.tools_to_bedrock_spec(tools)
payload = dict(
system=converted_system,
inferenceConfig=inference_config,
messages=converted_messages,
toolConfig=tool_config,
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
path = f"{self.client.host}model/{model}/converse"
response = self._post(payload, path)
response_message = response["output"]["message"]
usage_data = response["usage"]
usage = Usage(
input_tokens=usage_data.get("inputTokens"),
output_tokens=usage_data.get("outputTokens"),
total_tokens=usage_data.get("totalTokens"),
)
return self.response_to_message(response_message), usage
@retry_procedure
def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401
response = self.client.post(path, json=payload)
return raise_for_status(response).json()
@staticmethod
def message_to_bedrock_spec(message: Message) -> dict:
bedrock_content = []
try:
for content in message.content:
if isinstance(content, Text):
bedrock_content.append({"text": content.text})
elif isinstance(content, ToolUse):
for tool_use in message.tool_use:
bedrock_content.append(
{
"toolUse": {
"toolUseId": tool_use.id,
"name": tool_use.name,
"input": tool_use.parameters,
}
}
)
elif isinstance(content, ToolResult):
for tool_result in message.tool_result:
# try to parse the output as json
try:
output = json.loads(tool_result.output)
if isinstance(output, dict):
content = [{"json": output}]
else:
content = [{"text": str(output)}]
except json.JSONDecodeError:
content = [{"text": tool_result.output}]
bedrock_content.append(
{
"toolResult": {
"toolUseId": tool_result.tool_use_id,
"content": content,
**({"status": "error"} if tool_result.is_error else {}),
}
}
)
return {"role": message.role, "content": bedrock_content}
except AttributeError:
raise Exception("Invalid message")
@staticmethod
def response_to_message(response_message: dict) -> Message:
content = []
if response_message["role"] == "user":
for block in response_message["content"]:
if "text" in block:
content.append(Text(block["text"]))
if "toolResult" in block:
content.append(
ToolResult(
tool_use_id=block["toolResult"]["toolResultId"],
output=block["toolResult"]["content"][0]["json"],
is_error=block["toolResult"].get("status") == "error",
)
)
return Message(role="user", content=content)
elif response_message["role"] == "assistant":
for block in response_message["content"]:
if "text" in block:
content.append(Text(block["text"]))
if "toolUse" in block:
content.append(
ToolUse(
id=block["toolUse"]["toolUseId"],
name=block["toolUse"]["name"],
parameters=block["toolUse"]["input"],
)
)
return Message(role="assistant", content=content)
raise Exception("Invalid response")
@staticmethod
def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]:
if len(tools) == 0:
return None # API requires a non-empty tool config or None
tools_added = set()
tool_config_list = []
for tool in tools:
if tool.name in tools_added:
logging.warning(f"Tool {tool.name} already added to tool config. Skipping.")
continue
tool_config_list.append(
{
"toolSpec": {
"name": tool.name,
"description": tool.description,
"inputSchema": {"json": tool.parameters},
}
}
)
tools_added.add(tool.name)
tool_config = {"tools": tool_config_list}
return tool_config

View File

@@ -0,0 +1,102 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import raise_for_status, retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
tools_to_openai_spec,
)
from exchange.tool import Tool
retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)
class DatabricksProvider(Provider):
"""Provides chat completions for models on Databricks serving endpoints
Models are expected to follow the llm/v1/chat "task". This includes support
for foundation and external model endpoints.
https://docs.databricks.com/en/machine-learning/model-serving/create-foundation-model-endpoints.html#create-generative-ai-model-serving-endpoints
"""
def __init__(self, client: httpx.Client) -> None:
super().__init__()
self.client = client
@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
try:
url = os.environ["DATABRICKS_HOST"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
try:
key = os.environ["DATABRICKS_TOKEN"]
except KeyError:
raise RuntimeError(
"Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields"
)
client = httpx.Client(
base_url=url,
auth=("token", key),
timeout=httpx.Timeout(60 * 10),
)
return cls(client)
@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
payload = dict(
messages=[
{"role": "system", "content": system},
*messages_to_openai_spec(messages),
],
tools=tools_to_openai_spec(tools) if tools else [],
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(model, payload)
message = openai_response_to_message(response)
usage = self.get_usage(response)
return message, usage
@retry_procedure
def _post(self, model: str, payload: dict) -> httpx.Response:
response = self.client.post(
f"serving-endpoints/{model}/invocations",
json=payload,
)
return raise_for_status(response).json()
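# A minimal usage sketch (illustrative, not part of this module): requires
# DATABRICKS_HOST and DATABRICKS_TOKEN in the environment and a serving
# endpoint that implements the llm/v1/chat task. The endpoint name below is a
# placeholder, not a real endpoint of this project.
if __name__ == "__main__":
    provider = DatabricksProvider.from_env()
    reply, usage = provider.complete(
        model="my-chat-endpoint",  # hypothetical serving endpoint name
        system="You are a helpful assistant.",
        messages=[Message.user("Hello")],
        tools=(),
    )
    print(reply.content, usage)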

View File

@@ -0,0 +1,154 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
from exchange import Message, Tool
from exchange.content import Text, ToolResult, ToolUse
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import raise_for_status
GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"
retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)
class GoogleProvider(Provider):
def __init__(self, client: httpx.Client) -> None:
self.client = client
@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
try:
key = os.environ["GOOGLE_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key"
)
client = httpx.Client(
base_url=url,
headers={
"Content-Type": "application/json",
},
params={"key": key},
timeout=httpx.Timeout(60 * 10),
)
return cls(client)
@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
usage = data.get("usageMetadata")
input_tokens = usage.get("promptTokenCount")
output_tokens = usage.get("candidatesTokenCount")
total_tokens = usage.get("totalTokenCount")
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
@staticmethod
def google_response_to_message(response: Dict) -> Message:
candidates = response.get("candidates", [])
if candidates:
# Only use first candidate for now
candidate = candidates[0]
content_parts = candidate.get("content", {}).get("parts", [])
content = []
for part in content_parts:
if "text" in part:
content.append(Text(text=part["text"]))
elif "functionCall" in part:
content.append(
ToolUse(
id=part["functionCall"].get("name", ""),
name=part["functionCall"].get("name", ""),
parameters=part["functionCall"].get("args", {}),
)
)
return Message(role="assistant", content=content)
# If no valid candidates were found, return an empty message
return Message(role="assistant", content=[])
@staticmethod
def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]:
if not tools:
return {}
converted_tools = []
for tool in tools:
converted_tool: Dict[str, Any] = {
"name": tool.name,
"description": tool.description or "",
}
if tool.parameters["properties"]:
converted_tool["parameters"] = tool.parameters
converted_tools.append(converted_tool)
return {"functionDeclarations": converted_tools}
@staticmethod
def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]:
messages_spec = []
for message in messages:
role = "user" if message.role == "user" else "model"
converted = {"role": role, "parts": []}
for content in message.content:
if isinstance(content, Text):
converted["parts"].append({"text": content.text})
elif isinstance(content, ToolUse):
converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}})
elif isinstance(content, ToolResult):
converted["parts"].append(
{"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}}
)
messages_spec.append(converted)
if not messages_spec:
messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]})
return messages_spec
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
tools_set = set()
unique_tools = []
for tool in tools:
if tool.name not in tools_set:
unique_tools.append(tool)
tools_set.add(tool.name)
payload = dict(
system_instruction={"parts": [{"text": system}]},
contents=self.messages_to_google_spec(messages),
tools=self.tools_to_google_spec(tuple(unique_tools)),
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(payload, model)
message = self.google_response_to_message(response)
usage = self.get_usage(response)
return message, usage
@retry_procedure
def _post(self, payload: dict, model: str) -> httpx.Response:
response = self.client.post("models/" + model + ":generateContent", json=payload)
return raise_for_status(response).json()
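# A minimal sketch (illustrative, not part of this module): the generateContent
# payload that complete() assembles from the static converters above, shown
# without making a network call.
if __name__ == "__main__":
    messages = [Message.user("Hello, Gemini")]
    payload = dict(
        system_instruction={"parts": [{"text": "You are a helpful assistant."}]},
        contents=GoogleProvider.messages_to_google_spec(messages),
        tools=GoogleProvider.tools_to_google_spec(()),
    )
    # Empty values (here the unused tools entry) are dropped, mirroring complete().
    print({k: v for k, v in payload.items() if v})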

View File

@@ -0,0 +1,45 @@
import os
from typing import Type
import httpx
from exchange.providers.openai import OpenAiProvider
OLLAMA_HOST = "http://localhost:11434/"
OLLAMA_MODEL = "mistral-nemo"
class OllamaProvider(OpenAiProvider):
"""Provides chat completions for models hosted by Ollama"""
__doc__ += f"""
Here's an example profile configuration to try:
ollama:
provider: ollama
processor: {OLLAMA_MODEL}
accelerator: {OLLAMA_MODEL}
moderator: passive
toolkits:
- name: developer
requires: {{}}
"""
def __init__(self, client: httpx.Client) -> None:
print("PLEASE NOTE: the ollama provider is experimental, use with care")
super().__init__(client)
@classmethod
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
timeout = httpx.Timeout(60 * 10)
# from_env is expected to fail if required ENV variables are not
# available. Since this provider can run with defaults, we substitute
# an Ollama health check (GET /) to determine if the service is ok.
httpx.get(ollama_url, timeout=timeout)
# When served by Ollama, the OpenAI API is available at the path "v1/".
client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout)
return cls(client)
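# A minimal usage sketch (illustrative, not part of this module): assumes a
# local `ollama serve` is running and the model has already been pulled;
# OLLAMA_HOST can point at a different address.
if __name__ == "__main__":
    from exchange.message import Message

    provider = OllamaProvider.from_env()
    reply, usage = provider.complete(
        model=OLLAMA_MODEL,
        system="You are a helpful assistant.",
        messages=[Message.user("Hello")],
        tools=(),
    )
    print(reply.content, usage)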

View File

@@ -0,0 +1,101 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
from exchange.message import Message
from exchange.providers.base import Provider, Usage
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
)
from exchange.tool import Tool
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
OPENAI_HOST = "https://api.openai.com/"
retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)
class OpenAiProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI"""
def __init__(self, client: httpx.Client) -> None:
super().__init__()
self.client = client
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
try:
key = os.environ["OPENAI_API_KEY"]
except KeyError:
raise RuntimeError(
"Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys"
)
client = httpx.Client(
base_url=url + "v1/",
auth=("Bearer", key),
timeout=httpx.Timeout(60 * 10),
)
return cls(client)
@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
payload = dict(
messages=system_message + messages_to_openai_spec(messages),
model=model,
tools=tools_to_openai_spec(tools) if tools else [],
**kwargs,
)
payload = {k: v for k, v in payload.items() if v}
response = self._post(payload)
# Check for context_length_exceeded error for single, long input message
if "error" in response and len(messages) == 1:
openai_single_message_context_length_exceeded(response["error"])
message = openai_response_to_message(response)
usage = self.get_usage(response)
return message, usage
@retry_procedure
def _post(self, payload: dict) -> dict:
# Note: While OpenAI and Ollama mount the API under "v1", this is
# conventional and not a strict requirement. For example, Azure OpenAI
# mounts the API under the deployment name, and "v1" is not in the URL.
# See https://github.com/openai/openai-openapi/blob/master/openapi.yaml
response = self.client.post("chat/completions", json=payload)
return raise_for_status(response).json()
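# A small sketch (illustrative, not part of this module): get_usage falls back
# to summing prompt and completion tokens when the API omits total_tokens.
if __name__ == "__main__":
    usage = OpenAiProvider.get_usage({"usage": {"prompt_tokens": 18, "completion_tokens": 9}})
    print(usage)  # total_tokens is derived as 18 + 9 = 27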

View File

@@ -0,0 +1,185 @@
import base64
import json
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
import httpx
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.tool import Tool
from tenacity import retry_if_exception
def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable:
codes = codes or []
def predicate(exc: Exception) -> bool:
if isinstance(exc, httpx.HTTPStatusError):
if exc.response.status_code in codes:
return True
if above and exc.response.status_code >= above:
return True
return False
return retry_if_exception(predicate)
def raise_for_status(response: httpx.Response) -> httpx.Response:
"""Raise with reason text."""
try:
response.raise_for_status()
return response
except httpx.HTTPStatusError as e:
response.read()
if response.text:
raise httpx.HTTPStatusError(f"{e}\n{response.text}", request=e.request, response=e.response)
else:
raise e
def encode_image(image_path: str) -> str:
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]:
messages_spec = []
for message in messages:
converted = {"role": message.role}
output = []
for content in message.content:
if isinstance(content, Text):
converted["content"] = content.text
elif isinstance(content, ToolUse):
sanitized_name = re.sub(r"[^a-zA-Z0-9_-]", "_", content.name)
converted.setdefault("tool_calls", []).append(
{
"id": content.id,
"type": "function",
"function": {
"name": sanitized_name,
"arguments": json.dumps(content.parameters),
},
}
)
elif isinstance(content, ToolResult):
if content.output.startswith('"image:'):
image_path = content.output.replace('"image:', "").replace('"', "")
output.append(
{
"role": "tool",
"content": [
{
"type": "text",
"text": "This tool result included an image that is uploaded in the next message.",
},
],
"tool_call_id": content.tool_use_id,
}
)
# Note: this could be limited to the last message (message == messages[-1]),
# but keeping it for every message costs relatively few extra tokens.
output.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"},
}
],
}
)
else:
output.append(
{
"role": "tool",
"content": content.output,
"tool_call_id": content.tool_use_id,
}
)
if "content" in converted or "tool_calls" in converted:
output = [converted] + output
messages_spec.extend(output)
return messages_spec
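# A minimal sketch (illustrative, not part of this module): how a tool-use
# exchange is flattened into OpenAI chat messages by messages_to_openai_spec.
# The ids and file name are made-up examples.
if __name__ == "__main__":
    example = [
        Message(
            role="assistant",
            content=[ToolUse(id="call_1", name="read_file", parameters={"filename": "test.txt"})],
        ),
        Message(role="user", content=[ToolResult(tool_use_id="call_1", output="hello exchange")]),
    ]
    for entry in messages_to_openai_spec(example):
        # First an assistant message carrying tool_calls, then a role="tool"
        # message carrying the result keyed by tool_call_id.
        print(entry)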
def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]:
tools_names = set()
result = []
for tool in tools:
if tool.name in tools_names:
# we should never allow duplicate tools
raise ValueError(f"Duplicate tool name: {tool.name}")
result.append(
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
)
tools_names.add(tool.name)
return result
def openai_response_to_message(response: dict) -> Message:
original = response["choices"][0]["message"]
content = []
text = original.get("content")
if text:
content.append(Text(text=text))
tool_calls = original.get("tool_calls")
if tool_calls:
for tool_call in tool_calls:
try:
function_name = tool_call["function"]["name"]
# We occasionally see the model generate an invalid function name;
# sending it back to OpenAI raises a validation error.
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
content.append(
ToolUse(
id=tool_call["id"],
name=function_name,
parameters=tool_call["function"]["arguments"],
is_error=True,
error_message=f"The provided function name '{function_name}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", # noqa: E501
)
)
else:
content.append(
ToolUse(
id=tool_call["id"],
name=function_name,
parameters=json.loads(tool_call["function"]["arguments"]),
)
)
except json.JSONDecodeError:
content.append(
ToolUse(
id=tool_call["id"],
name=tool_call["function"]["name"],
parameters=tool_call["function"]["arguments"],
is_error=True,
error_message=f"Could not interpret tool use parameters for id {tool_call['id']}: {tool_call['function']['arguments']}", # noqa: E501
)
)
return Message(role="assistant", content=content)
def openai_single_message_context_length_exceeded(error_dict: dict) -> None:
code = error_dict.get("code")
if code == "context_length_exceeded" or code == "string_above_max_length":
raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}")
class InitialMessageTooLargeError(Exception):
"""Custom error raised when the first input message in an exchange is too large."""
pass

View File

@@ -0,0 +1,27 @@
from collections import defaultdict
from typing import Dict
from exchange.providers.base import Usage
class _TokenUsageCollector:
def __init__(self) -> None:
self.usage_data = []
def collect(self, model: str, usage: Usage) -> None:
self.usage_data.append((model, usage))
def get_token_usage_group_by_model(self) -> Dict[str, Usage]:
usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0))
for model, usage in self.usage_data:
usage_by_model = usage_group_by_model[model]
if usage is not None and usage.input_tokens is not None:
usage_by_model.input_tokens += usage.input_tokens
if usage is not None and usage.output_tokens is not None:
usage_by_model.output_tokens += usage.output_tokens
if usage is not None and usage.total_tokens is not None:
usage_by_model.total_tokens += usage.total_tokens
return usage_group_by_model
_token_usage_collector = _TokenUsageCollector()
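# A small sketch (illustrative, not part of this module): aggregating token
# usage per model with the collector above.
if __name__ == "__main__":
    collector = _TokenUsageCollector()
    collector.collect("gpt-4o-mini", Usage(input_tokens=18, output_tokens=9, total_tokens=27))
    collector.collect("gpt-4o-mini", Usage(input_tokens=10, output_tokens=5, total_tokens=15))
    # -> {"gpt-4o-mini": Usage(input_tokens=28, output_tokens=14, total_tokens=42)}
    print(dict(collector.get_token_usage_group_by_model()))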

View File

@@ -0,0 +1,55 @@
import inspect
from typing import Any, Callable, Type
from attrs import define
from exchange.utils import json_schema, parse_docstring
@define
class Tool:
"""A tool that can be used by a model.
Attributes:
name (str): The name of the tool
description (str): A description of what the tool does
parameters (dict[str, Any]): A JSON schema describing the function signature
function (Callable): The python function that powers the tool
"""
name: str
description: str
parameters: dict[str, Any]
function: Callable
@classmethod
def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401
"""Create a tool instance from a function and its docstring
The function must have a docstring; it is used to load the description
and parameter descriptions. A class instance with a __call__ method is
also supported.
"""
if inspect.isfunction(func) or inspect.ismethod(func):
name = func.__name__
else:
name = func.__class__.__name__.lower()
func = func.__call__
description, param_descriptions = parse_docstring(func)
schema = json_schema(func)
# Set the 'description' field of the schema to the arg's docstring description
for arg in param_descriptions:
arg_name, arg_description = arg["name"], arg["description"]
if arg_name not in schema["properties"]:
raise ValueError(f"Argument {arg_name} found in docstring but not in schema")
schema["properties"][arg_name]["description"] = arg_description
return cls(
name=name,
description=description,
parameters=schema,
function=func,
)
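# A minimal sketch (illustrative, not part of this module): building a Tool from
# a documented function. get_weather is a made-up example.
if __name__ == "__main__":
    def get_weather(city: str, unit: str = "celsius") -> str:
        """Look up the current weather for a city.

        Args:
            city (str): Name of the city to look up.
            unit (str): Temperature unit, either "celsius" or "fahrenheit".
        """
        return f"The weather in {city} is mild, in {unit}."

    tool = Tool.from_function(get_weather)
    print(tool.name)  # get_weather
    print(tool.description)  # Look up the current weather for a city.
    # JSON schema with "city" required and "unit" defaulting to "celsius"
    print(tool.parameters)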

View File

@@ -0,0 +1,155 @@
import inspect
import uuid
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, List, Type, get_args, get_origin
from griffe import (
Docstring,
DocstringSection,
DocstringSectionParameters,
DocstringSectionText,
)
def create_object_id(prefix: str) -> str:
return f"{prefix}_{uuid.uuid4().hex[:24]}"
def compact(content: str) -> str:
"""Replace any amount of whitespace with a single space"""
return " ".join(content.split())
def parse_docstring(func: Callable) -> tuple[str, List[Dict]]:
"""Get description and parameters from function docstring"""
function_args = list(inspect.signature(func).parameters.keys())
text = str(func.__doc__)
docstring = Docstring(text)
for style in ["google", "numpy", "sphinx"]:
parsed = docstring.parse(style)
if not _check_section_is_present(parsed, DocstringSectionText):
continue
if function_args and not _check_section_is_present(parsed, DocstringSectionParameters):
continue
break
else: # if we did not find a valid style in the for loop
raise ValueError(
f"Attempted to load from a function {func.__name__} with an invalid docstring. Parameter docs are required if the function has parameters. https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings" # noqa: E501
)
description = None
parameters = []
for section in parsed:
if isinstance(section, DocstringSectionText):
description = compact(section.value)
elif isinstance(section, DocstringSectionParameters):
parameters = [arg.as_dict() for arg in section.value]
docstring_args = [d["name"] for d in parameters]
if description is None:
raise ValueError("Docstring must include a description.")
if not docstring_args == function_args:
extra_docstring_args = ", ".join(sorted(set(docstring_args) - set(function_args)))
extra_function_args = ", ".join(sorted(set(function_args) - set(docstring_args)))
if extra_docstring_args and extra_function_args:
raise ValueError(
f"Docstring args must match function args: docstring had extra {extra_docstring_args}; function had extra {extra_function_args}" # noqa: E501
)
elif extra_function_args:
raise ValueError(f"Docstring args must match function args: function had extra {extra_function_args}")
elif extra_docstring_args:
raise ValueError(f"Docstring args must match function args: docstring had extra {extra_docstring_args}")
else:
raise ValueError("Docstring args must match function args")
return description, parameters
def _check_section_is_present(
parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText]
) -> bool:
for section in parsed_docstring:
if isinstance(section, section_type):
return True
return False
def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401
"""Get the json schema for a function"""
signature = inspect.signature(func)
parameters = signature.parameters
schema = {
"type": "object",
"properties": {},
"required": [],
}
for param_name, param in parameters.items():
param_schema = {}
if param.annotation is not inspect.Parameter.empty:
param_schema = _map_type_to_schema(param.annotation)
if param.default is not inspect.Parameter.empty:
param_schema["default"] = param.default
schema["properties"][param_name] = param_schema
if param.default is inspect.Parameter.empty:
schema["required"].append(param_name)
return schema
def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401
origin = get_origin(py_type)
args = get_args(py_type)
if origin is list or origin is tuple:
return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)}
elif origin is dict:
return {
"type": "object",
"additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any),
}
elif py_type is int:
return {"type": "integer"}
elif py_type is bool:
return {"type": "boolean"}
elif py_type is float:
return {"type": "number"}
elif py_type is str:
return {"type": "string"}
else:
return {"type": "string"}
def load_plugins(group: str) -> dict:
"""
Load plugins based on a specified entry point group.
This function iterates through all entry points registered under a specified group and loads each one.
Args:
group (str): The entry point group to load plugins from. This should match the group specified
in the package setup where plugins are defined.
Returns:
dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object.
Raises:
Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin
is not found or if there are issues with the plugin's code.
"""
plugins = {}
# Access all entry points for the specified group and load each.
for entrypoint in entry_points(group=group):
plugin = entrypoint.load() # Load the plugin.
plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary.
return plugins
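# A minimal sketch (illustrative, not part of this module): discovering plugins
# registered under an entry point group. The group and package names below are
# hypothetical examples, not names this project necessarily defines, e.g. in a
# plugin's pyproject.toml:
#
#   [project.entry-points."exchange.provider"]
#   my_provider = "my_package.provider:MyProvider"
if __name__ == "__main__":
    for name, plugin in load_plugins(group="exchange.provider").items():  # hypothetical group name
        print(name, plugin)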

View File

@@ -0,0 +1,2 @@
lint.select = ["E", "W", "F", "N"]
line-length = 120

View File

@@ -0,0 +1 @@
"""Tests for exchange."""

View File

@@ -0,0 +1,36 @@
import pytest
from exchange.providers.base import Usage
@pytest.fixture
def dummy_tool():
def _dummy_tool() -> str:
"""An example tool"""
return "dummy response"
return _dummy_tool
@pytest.fixture
def usage_factory():
def _create_usage(input_tokens=100, output_tokens=200, total_tokens=300):
return Usage(input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens)
return _create_usage
def read_file(filename: str) -> str:
"""
Read the contents of the file.
Args:
filename (str): The path to the file, which can be relative or
absolute. If it is a plain filename, it is assumed to be in the
current working directory.
Returns:
str: The contents of the file.
"""
assert filename == "test.txt"
return "hello exchange"

View File

@@ -0,0 +1 @@
"""Tests for chat completion providers."""

View File

@@ -0,0 +1,68 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
api-key:
- test_azure_api_key
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- test.openai.azure.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
response:
body:
string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello!
How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}}
'
headers:
Cache-Control:
- no-cache, must-revalidate
Content-Length:
- '825'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 02:07:45 GMT
Set-Cookie: test_set_cookie
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
access-control-allow-origin:
- '*'
apim-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
azureml-model-session:
- d145-20240919052126
openai-organization: test_openai_org_key
x-accel-buffering:
- 'no'
x-content-type-options:
- nosniff
x-ms-client-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
x-ms-rai-invoked:
- 'true'
x-ms-region:
- Switzerland North
x-ratelimit-remaining-requests:
- '79'
x-ratelimit-remaining-tokens:
- '79984'
x-request-id:
- 38db9001-8b16-4efe-84c9-620e10f18c3c
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,74 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant.
Expect to need to read a file using read_file."}, {"role": "user", "content":
"What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools":
[{"type": "function", "function": {"name": "read_file", "description": "Read
the contents of the file.", "parameters": {"type": "object", "properties": {"filename":
{"type": "string", "description": "The path to the file, which can be relative
or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent
working directory."}}, "required": ["filename"]}}}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
api-key:
- test_azure_api_key
connection:
- keep-alive
content-length:
- '608'
content-type:
- application/json
host:
- test.openai.azure.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
response:
body:
string: '{"choices":[{"content_filter_results":{},"finish_reason":"tool_calls","index":0,"logprobs":null,"message":{"content":null,"role":"assistant","tool_calls":[{"function":{"arguments":"{\n \"filename\":
\"test.txt\"\n}","name":"read_file"},"id":"call_a47abadDxlGKIWjvYYvGVAHa","type":"function"}]}}],"created":1727256650,"id":"chatcmpl-ABIeABbq5WVCq0e0AriGFaYDSih3P","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":16,"prompt_tokens":109,"total_tokens":125}}
'
headers:
Cache-Control:
- no-cache, must-revalidate
Content-Length:
- '769'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 09:30:50 GMT
Set-Cookie: test_set_cookie
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
access-control-allow-origin:
- '*'
apim-request-id:
- 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339
azureml-model-session:
- d145-20240919052126
openai-organization: test_openai_org_key
x-accel-buffering:
- 'no'
x-content-type-options:
- nosniff
x-ms-client-request-id:
- 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339
x-ms-rai-invoked:
- 'true'
x-ms-region:
- Switzerland North
x-ratelimit-remaining-requests:
- '79'
x-ratelimit-remaining-tokens:
- '79824'
x-request-id:
- 401bd803-b790-47b7-b098-98708d44f060
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,68 @@
interactions:
- request:
body: ''
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
host:
- localhost:11434
user-agent:
- python-httpx/0.27.2
method: GET
uri: http://localhost:11434/
response:
body:
string: Ollama is running
headers:
Content-Length:
- '17'
Content-Type:
- text/plain; charset=utf-8
Date:
- Sun, 22 Sep 2024 23:40:13 GMT
Set-Cookie: test_set_cookie
openai-organization: test_openai_org_key
status:
code: 200
message: OK
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}], "model": "mistral-nemo"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '140'
content-type:
- application/json
host:
- localhost:11434
user-agent:
- python-httpx/0.27.2
method: POST
uri: http://localhost:11434/v1/chat/completions
response:
body:
string: "{\"id\":\"chatcmpl-429\",\"object\":\"chat.completion\",\"created\":1727048416,\"model\":\"mistral-nemo\",\"system_fingerprint\":\"fp_ollama\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello!
I'm here to help. How can I assist you today? Let's chat. \U0001F60A\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":23,\"total_tokens\":33}}\n"
headers:
Content-Length:
- '356'
Content-Type:
- application/json
Date:
- Sun, 22 Sep 2024 23:40:16 GMT
Set-Cookie: test_set_cookie
openai-organization: test_openai_org_key
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,75 @@
interactions:
- request:
body: ''
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
host:
- localhost:11434
user-agent:
- python-httpx/0.27.2
method: GET
uri: http://localhost:11434/
response:
body:
string: Ollama is running
headers:
Content-Length:
- '17'
Content-Type:
- text/plain; charset=utf-8
Date:
- Wed, 25 Sep 2024 09:23:08 GMT
Set-Cookie: test_set_cookie
openai-organization: test_openai_org_key
status:
code: 200
message: OK
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant.
Expect to need to read a file using read_file."}, {"role": "user", "content":
"What are the contents of this file? test.txt"}], "model": "mistral-nemo", "tools":
[{"type": "function", "function": {"name": "read_file", "description": "Read
the contents of the file.", "parameters": {"type": "object", "properties": {"filename":
{"type": "string", "description": "The path to the file, which can be relative
or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent
working directory."}}, "required": ["filename"]}}}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '609'
content-type:
- application/json
host:
- localhost:11434
user-agent:
- python-httpx/0.27.2
method: POST
uri: http://localhost:11434/v1/chat/completions
response:
body:
string: '{"id":"chatcmpl-245","object":"chat.completion","created":1727256190,"model":"mistral-nemo","system_fingerprint":"fp_ollama","choices":[{"index":0,"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_z6fgu3z3","type":"function","function":{"name":"read_file","arguments":"{\"filename\":\"test.txt\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":112,"completion_tokens":21,"total_tokens":133}}
'
headers:
Content-Length:
- '425'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 09:23:10 GMT
Set-Cookie: test_set_cookie
openai-organization: test_openai_org_key
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,80 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- Bearer test_openai_api_key
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-AAQTYi3DXJnltAfd5sUH1Wnzh69t3\",\n \"object\":
\"chat.completion\",\n \"created\": 1727048416,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n
\ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\":
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 18,\n \"completion_tokens\":
9,\n \"total_tokens\": 27,\n \"completion_tokens_details\": {\n \"reasoning_tokens\":
0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n"
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8c762399feb55739-SYD
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Sun, 22 Sep 2024 23:40:17 GMT
Server:
- cloudflare
Set-Cookie: test_set_cookie
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
content-length:
- '593'
openai-organization: test_openai_org_key
openai-processing-ms:
- '560'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=15552000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '200000'
x-ratelimit-remaining-requests:
- '9999'
x-ratelimit-remaining-tokens:
- '199973'
x-ratelimit-reset-requests:
- 8.64s
x-ratelimit-reset-tokens:
- 8ms
x-request-id:
- req_22e26c840219cde3152eaba1ce89483b
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,90 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant.
Expect to need to read a file using read_file."}, {"role": "user", "content":
"What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools":
[{"type": "function", "function": {"name": "read_file", "description": "Read
the contents of the file.", "parameters": {"type": "object", "properties": {"filename":
{"type": "string", "description": "The path to the file, which can be relative
or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent
working directory."}}, "required": ["filename"]}}}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
authorization:
- Bearer test_openai_api_key
connection:
- keep-alive
content-length:
- '608'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-ABIV2aZWVKQ774RAQ8KHYdNwkI5N7\",\n \"object\":
\"chat.completion\",\n \"created\": 1727256084,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
\ \"id\": \"call_xXYlw4A7Ud1qtCopuK5gEJrP\",\n \"type\":
\"function\",\n \"function\": {\n \"name\": \"read_file\",\n
\ \"arguments\": \"{\\\"filename\\\":\\\"test.txt\\\"}\"\n }\n
\ }\n ],\n \"refusal\": null\n },\n \"logprobs\":
null,\n \"finish_reason\": \"tool_calls\"\n }\n ],\n \"usage\":
{\n \"prompt_tokens\": 107,\n \"completion_tokens\": 15,\n \"total_tokens\":
122,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0\n
\ }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n"
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8c89f19fed997e43-SYD
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 09:21:25 GMT
Server:
- cloudflare
Set-Cookie: test_set_cookie
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
content-length:
- '844'
openai-organization: test_openai_org_key
openai-processing-ms:
- '266'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '10000'
x-ratelimit-limit-tokens:
- '200000'
x-ratelimit-remaining-requests:
- '9991'
x-ratelimit-remaining-tokens:
- '199952'
x-ratelimit-reset-requests:
- 1m9.486s
x-ratelimit-reset-tokens:
- 14ms
x-request-id:
- req_ff6b5d65c24f40e1faaf049c175e718d
status:
code: 200
message: OK
version: 1

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,131 @@
import os
import re
from typing import Type, Tuple
import pytest
from exchange import Message, ToolUse, ToolResult, Tool
from exchange.providers import Usage, Provider
from tests.conftest import read_file
OPENAI_API_KEY = "test_openai_api_key"
OPENAI_ORG_ID = "test_openai_org_key"
OPENAI_PROJECT_ID = "test_openai_project_id"
@pytest.fixture
def default_openai_env(monkeypatch):
"""
This fixture prevents OpenAIProvider.from_env() from failing on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "OPENAI_API_KEY" not in os.environ:
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)
AZURE_ENDPOINT = "https://test.openai.azure.com"
AZURE_DEPLOYMENT_NAME = "test-azure-deployment"
AZURE_API_VERSION = "2024-05-01-preview"
AZURE_API_KEY = "test_azure_api_key"
@pytest.fixture
def default_azure_env(monkeypatch):
"""
This fixture prevents AzureProvider.from_env() from failing on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION)
if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)
@pytest.fixture(scope="module")
def vcr_config():
"""
This scrubs sensitive data and gunzips bodies when in recording mode.
Without this, you would leak cookies and auth tokens in the cassettes.
Also, depending on the request, some responses would be binary encoded
while others are plain JSON. This ensures all bodies are human-readable.
"""
return {
"decode_compressed_response": True,
"filter_headers": [
("authorization", "Bearer " + OPENAI_API_KEY),
("openai-organization", OPENAI_ORG_ID),
("openai-project", OPENAI_PROJECT_ID),
("cookie", None),
],
"before_record_request": scrub_request_url,
"before_record_response": scrub_response_headers,
}
def scrub_request_url(request):
"""
This scrubs sensitive request data in a provider-specific way. Note that headers
are case-sensitive!
"""
if "openai.azure.com" in request.uri:
request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri)
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY
return request
def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
"""
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response
def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant. Expect to need to read a file using read_file."
messages = [Message.user("What are the contents of this file? test.txt")]
return provider.complete(
model=model, system=system, messages=messages, tools=(Tool.from_function(read_file),), **kwargs
)
def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [
Message.user("What does the first entry in the menu say?"),
Message(
role="assistant",
content=[ToolUse(id="xyz", name="screenshot", parameters={})],
),
Message(
role="user",
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
),
]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)

View File

@@ -0,0 +1,174 @@
import os
from unittest.mock import patch
import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.anthropic import AnthropicProvider
from exchange.tool import Tool
def example_fn(param: str) -> None:
"""
Testing function.
Args:
param (str): Description of param1
"""
pass
@pytest.fixture
@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_api_key"})
def anthropic_provider():
return AnthropicProvider.from_env()
def test_anthropic_response_to_text_message() -> None:
response = {
"content": [{"type": "text", "text": "Hello from Claude!"}],
}
message = AnthropicProvider.anthropic_response_to_message(response)
assert message.content[0].text == "Hello from Claude!"
def test_anthropic_response_to_tool_use_message() -> None:
response = {
"content": [
{
"type": "tool_use",
"id": "1",
"name": "example_fn",
"input": {"param": "value"},
}
],
}
message = AnthropicProvider.anthropic_response_to_message(response)
assert message.content[0].id == "1"
assert message.content[0].name == "example_fn"
assert message.content[0].parameters == {"param": "value"}
def test_tools_to_anthropic_spec() -> None:
tools = (Tool.from_function(example_fn),)
expected_spec = [
{
"name": "example_fn",
"description": "Testing function.",
"input_schema": {
"type": "object",
"properties": {"param": {"type": "string", "description": "Description of param1"}},
"required": ["param"],
},
}
]
result = AnthropicProvider.tools_to_anthropic_spec(tools)
assert result == expected_spec
def test_message_text_to_anthropic_spec() -> None:
messages = [Message.user("Hello, Claude")]
expected_spec = [
{
"role": "user",
"content": [{"type": "text", "text": "Hello, Claude"}],
}
]
result = AnthropicProvider.messages_to_anthropic_spec(messages)
assert result == expected_spec
def test_messages_to_anthropic_spec() -> None:
messages = [
Message(role="user", content=[Text(text="Hello, Claude")]),
Message(
role="assistant",
content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
),
Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
]
actual_spec = AnthropicProvider.messages_to_anthropic_spec(messages)
# !=
expected_spec = [
{"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}]},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "1",
"name": "example_fn",
"input": {"param": "value"},
}
],
},
{
"role": "user",
"content": [{"type": "tool_result", "tool_use_id": "1", "content": "Result"}],
},
]
assert actual_spec == expected_spec
@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_anthropic_completion(mock_error, mock_warning, mock_post, anthropic_provider):
mock_response = {
"content": [{"type": "text", "text": "Hello from Claude!"}],
"usage": {"input_tokens": 10, "output_tokens": 25},
}
# The first attempt fails with status code 429, the second succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = httpx._content.json_dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
return response
mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]
model = "claude-3-5-sonnet-20240620"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Claude")]
reply_message, reply_usage = anthropic_provider.complete(model=model, system=system, messages=messages)
assert reply_message.content == [Text(text="Hello from Claude!")]
assert reply_usage.total_tokens == 35
assert mock_post.call_count == 2
mock_post.assert_any_call(
"https://api.anthropic.com/v1/messages",
json={
"system": system,
"model": model,
"max_tokens": 4096,
"messages": [
*[
{
"role": msg.role,
"content": [{"type": "text", "text": msg.content[0].text}],
}
for msg in messages
],
],
},
)
@pytest.mark.integration
def test_anthropic_integration():
provider = AnthropicProvider.from_env()
model = "claude-3-5-sonnet-20240620" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Claude")]
# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)
assert reply[0].content is not None
print("Completion content from Anthropic:", reply[0].content)

View File

@@ -0,0 +1,48 @@
import os
import pytest
from exchange import Text, ToolUse
from exchange.providers.azure import AzureProvider
from .conftest import complete, tools
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)
assert reply_message.content == [Text(text="Hello! How can I assist you today?")]
assert reply_usage.total_tokens == 27
@pytest.mark.integration
def test_azure_complete_integration():
reply = complete(AzureProvider, AZURE_MODEL)
assert reply[0].content is not None
print("Completion content from Azure:", reply[0].content)
@pytest.mark.vcr()
def test_azure_tools(default_azure_env):
reply_message, reply_usage = tools(AzureProvider, AZURE_MODEL)
tool_use = reply_message.content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id == "call_a47abadDxlGKIWjvYYvGVAHa"
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
assert reply_usage.total_tokens == 125
@pytest.mark.integration
def test_azure_tools_integration():
reply = tools(AzureProvider, AZURE_MODEL)
tool_use = reply[0].content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id is not None
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}

View File

@@ -0,0 +1,228 @@
import logging
import os
from unittest.mock import patch
import pytest
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.bedrock import BedrockProvider
from exchange.tool import Tool
logger = logging.getLogger(__name__)
@pytest.fixture
@patch.dict(
os.environ,
{
"AWS_REGION": "us-east-1",
"AWS_ACCESS_KEY_ID": "fake-access-key",
"AWS_SECRET_ACCESS_KEY": "fake-secret-key",
"AWS_SESSION_TOKEN": "fake-session-token",
},
)
def bedrock_provider():
return BedrockProvider.from_env()
@patch("time.time", return_value=1624250000)
def test_sign_and_get_headers(mock_time, bedrock_provider):
# Create sample values
method = "POST"
url = "https://bedrock-runtime.us-east-1.amazonaws.com/some/path"
payload = {"key": "value"}
service = "bedrock"
# Generate headers
headers = bedrock_provider.client.sign_and_get_headers(
method,
url,
payload,
service,
)
# Assert that headers contain expected keys
assert "Authorization" in headers
assert "Content-Type" in headers
assert "X-Amz-date" in headers
assert "x-amz-content-sha256" in headers
assert "X-Amz-Security-Token" in headers
@patch("httpx.Client.post")
def test_complete(mock_post, bedrock_provider):
# Mocked response from the server
mock_response = {
"output": {"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}},
"usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25},
}
mock_post.return_value.json.return_value = mock_response
model = "test-model"
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
tools = ()
reply_message, reply_usage = bedrock_provider.complete(model=model, system=system, messages=messages, tools=tools)
# Assertions for reply message
assert reply_message.content[0].text == "Hello, world!"
assert reply_usage.total_tokens == 25
def test_message_to_bedrock_spec_text(bedrock_provider):
message = Message(role="user", content=[Text("Hello, world!")])
expected = {"role": "user", "content": [{"text": "Hello, world!"}]}
assert bedrock_provider.message_to_bedrock_spec(message) == expected
def test_message_to_bedrock_spec_tool_use(bedrock_provider):
tool_use = ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"})
message = Message(role="assistant", content=[tool_use])
expected = {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool-1",
"name": "WordCount",
"input": {"text": "Hello, world!"},
}
}
],
}
assert bedrock_provider.message_to_bedrock_spec(message) == expected
def test_message_to_bedrock_spec_tool_result(bedrock_provider):
message = Message(
role="assistant",
content=[ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"})],
)
expected = {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool-1",
"name": "WordCount",
"input": {"text": "Hello, world!"},
}
}
],
}
assert bedrock_provider.message_to_bedrock_spec(message) == expected
def test_message_to_bedrock_spec_tool_result_text(bedrock_provider):
tool_result = ToolResult(tool_use_id="tool-1", output="Error occurred", is_error=True)
message = Message(role="user", content=[tool_result])
expected = {
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tool-1",
"content": [{"text": "Error occurred"}],
"status": "error",
}
}
],
}
assert bedrock_provider.message_to_bedrock_spec(message) == expected
def test_message_to_bedrock_spec_invalid(bedrock_provider):
with pytest.raises(Exception):
bedrock_provider.message_to_bedrock_spec(Message(role="user", content=[]))
def test_response_to_message_text(bedrock_provider):
response = {"role": "user", "content": [{"text": "Hello, world!"}]}
message = bedrock_provider.response_to_message(response)
assert message.role == "user"
assert message.content[0].text == "Hello, world!"
def test_response_to_message_tool_use(bedrock_provider):
response = {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tool-1",
"name": "WordCount",
"input": {"text": "Hello, world!"},
}
}
],
}
message = bedrock_provider.response_to_message(response)
assert message.role == "assistant"
assert message.content[0].name == "WordCount"
assert message.content[0].parameters == {"text": "Hello, world!"}
def test_response_to_message_tool_result(bedrock_provider):
response = {
"role": "user",
"content": [
{
"toolResult": {
"toolResultId": "tool-1",
"content": [{"json": {"result": 2}}],
}
}
],
}
message = bedrock_provider.response_to_message(response)
assert message.role == "user"
assert message.content[0].tool_use_id == "tool-1"
assert message.content[0].output == {"result": 2}
def test_response_to_message_invalid(bedrock_provider):
with pytest.raises(Exception):
bedrock_provider.response_to_message({})
def test_tools_to_bedrock_spec(bedrock_provider):
def word_count(text: str):
return len(text.split())
tool = Tool(
name="WordCount",
description="Counts words.",
parameters={"text": "string"},
function=word_count,
)
expected = {
"tools": [
{
"toolSpec": {
"name": "WordCount",
"description": "Counts words.",
"inputSchema": {"json": {"text": "string"}},
}
}
]
}
assert bedrock_provider.tools_to_bedrock_spec((tool,)) == expected
def test_tools_to_bedrock_spec_duplicate(bedrock_provider):
def word_count(text: str):
return len(text.split())
tool = Tool(
name="WordCount",
description="Counts words.",
parameters={"text": "string"},
function=word_count,
)
tool_duplicate = Tool(
name="WordCount",
description="Counts words.",
parameters={"text": "string"},
function=word_count,
)
tools = bedrock_provider.tools_to_bedrock_spec((tool, tool_duplicate))
assert set(tool["toolSpec"]["name"] for tool in tools["tools"]) == {"WordCount"}

View File

@@ -0,0 +1,49 @@
import os
from unittest.mock import patch
import pytest
from exchange import Message, Text
from exchange.providers.databricks import DatabricksProvider
@pytest.fixture
@patch.dict(
os.environ,
{"DATABRICKS_HOST": "http://test-host", "DATABRICKS_TOKEN": "test_token"},
)
def databricks_provider():
return DatabricksProvider.from_env()
@patch("httpx.Client.post")
@patch("time.sleep", return_value=None)
@patch("logging.warning")
@patch("logging.error")
def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, databricks_provider):
mock_response = {
"choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35},
}
mock_post.return_value.json.return_value = mock_response
model = "my-databricks-model"
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
tools = ()
reply_message, reply_usage = databricks_provider.complete(
model=model, system=system, messages=messages, tools=tools
)
assert reply_message.content == [Text(text="Hello!")]
assert reply_usage.total_tokens == 35
assert mock_post.call_count == 1
mock_post.assert_called_once_with(
"serving-endpoints/my-databricks-model/invocations",
json={
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": "Hello"},
]
},
)

View File

@@ -0,0 +1,147 @@
import os
from unittest.mock import patch
import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool
def example_fn(param: str) -> None:
"""
Testing function.
Args:
param (str): Description of param1
"""
pass
@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():
return GoogleProvider.from_env()
def test_google_response_to_text_message() -> None:
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
message = GoogleProvider.google_response_to_message(response)
assert message.content[0].text == "Hello from Gemini!"
def test_google_response_to_tool_use_message() -> None:
response = {
"candidates": [
{
"content": {
"parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}],
"role": "model",
}
}
]
}
message = GoogleProvider.google_response_to_message(response)
assert message.content[0].name == "example_fn"
assert message.content[0].parameters == {"param": "value"}
def test_tools_to_google_spec() -> None:
tools = (Tool.from_function(example_fn),)
expected_spec = {
"functionDeclarations": [
{
"name": "example_fn",
"description": "Testing function.",
"parameters": {
"type": "object",
"properties": {"param": {"type": "string", "description": "Description of param1"}},
"required": ["param"],
},
}
]
}
result = GoogleProvider.tools_to_google_spec(tools)
assert result == expected_spec
def test_message_text_to_google_spec() -> None:
messages = [Message.user("Hello, Gemini")]
expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}]
result = GoogleProvider.messages_to_google_spec(messages)
assert result == expected_spec
def test_messages_to_google_spec() -> None:
messages = [
Message(role="user", content=[Text(text="Hello, Gemini")]),
Message(
role="assistant",
content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
),
Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
]
actual_spec = GoogleProvider.messages_to_google_spec(messages)
# !=
expected_spec = [
{"role": "user", "parts": [{"text": "Hello, Gemini"}]},
{"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]},
{"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]},
]
assert actual_spec == expected_spec
@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
mock_response = {
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
}
# The first attempt fails with status code 429; the second succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = json.dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
return response
mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]
model = "gemini-1.5-flash"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)
assert reply_message.content == [Text(text="Hello from Gemini!")]
assert reply_usage.total_tokens == 13
assert mock_post.call_count == 2
mock_post.assert_any_call(
"models/gemini-1.5-flash:generateContent",
json={
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
},
)
@pytest.mark.integration
def test_google_integration():
provider = GoogleProvider.from_env()
model = "gemini-1.5-flash" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)
assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)

View File

@@ -0,0 +1,48 @@
import os
import pytest
from exchange import Text, ToolUse
from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL
from .conftest import complete, tools
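# complete and tools are shared conftest helpers; judging by their use below, they run a
# one-shot exchange against the given provider and return a (message, usage) pair.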
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL)
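# vcr-marked tests replay recorded HTTP cassettes and run offline; the integration-marked
# tests below exercise a live Ollama server instead.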
@pytest.mark.vcr()
def test_ollama_complete():
reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL)
assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")]
assert reply_usage.total_tokens == 33
@pytest.mark.integration
def test_ollama_complete_integration():
reply = complete(OllamaProvider, OLLAMA_MODEL)
assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)
@pytest.mark.vcr()
def test_ollama_tools():
reply_message, reply_usage = tools(OllamaProvider, OLLAMA_MODEL)
tool_use = reply_message.content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id == "call_z6fgu3z3"
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
assert reply_usage.total_tokens == 133
@pytest.mark.integration
def test_ollama_tools_integration():
reply = tools(OllamaProvider, OLLAMA_MODEL, seed=3)
tool_use = reply[0].content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id is not None
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}

View File

@@ -0,0 +1,63 @@
import os
import pytest
from exchange import Text, ToolUse
from exchange.providers.openai import OpenAiProvider
from .conftest import complete, vision, tools
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
@pytest.mark.vcr()
def test_openai_complete(default_openai_env):
reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL)
assert reply_message.content == [Text(text="Hello! How can I assist you today?")]
assert reply_usage.total_tokens == 27
@pytest.mark.integration
def test_openai_complete_integration():
reply = complete(OpenAiProvider, OPENAI_MODEL)
assert reply[0].content is not None
print("Completion content from OpenAI:", reply[0].content)
@pytest.mark.vcr()
def test_openai_tools(default_openai_env):
reply_message, reply_usage = tools(OpenAiProvider, OPENAI_MODEL)
tool_use = reply_message.content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id == "call_xXYlw4A7Ud1qtCopuK5gEJrP"
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
assert reply_usage.total_tokens == 122
@pytest.mark.integration
def test_openai_tools_integration():
reply = tools(OpenAiProvider, OPENAI_MODEL)
tool_use = reply[0].content[0]
assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"
assert tool_use.id is not None
assert tool_use.name == "read_file"
assert tool_use.parameters == {"filename": "test.txt"}
@pytest.mark.vcr()
def test_openai_vision(default_openai_env):
reply_message, reply_usage = vision(OpenAiProvider, OPENAI_MODEL)
assert reply_message.content == [Text(text='The first entry in the menu says "Ask Goose."')]
assert reply_usage.total_tokens == 14241
@pytest.mark.integration
def test_openai_vision_integration():
reply = vision(OpenAiProvider, OPENAI_MODEL)
assert "ask goose" in reply[0].text.lower()

View File

@@ -0,0 +1,245 @@
from copy import deepcopy
import json
from unittest.mock import Mock
from attrs import asdict
import httpx
import pytest
from unittest.mock import patch
from exchange.content import Text, ToolResult, ToolUse
from exchange.message import Message
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
raise_for_status,
tools_to_openai_spec,
)
from exchange.tool import Tool
OPEN_AI_TOOL_USE_RESPONSE = {
"choices": [
{
"role": "assistant",
"message": {
"tool_calls": [
{
"id": "1",
"function": {
"name": "example_fn",
"arguments": json.dumps(
{
"param": "value",
}
),
# TODO: should this handle dicts as well?
},
}
],
},
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 25,
"total_tokens": 35,
},
}
def example_fn(param: str) -> None:
"""
Testing function.
Args:
param (str): Description of param1
"""
pass
def example_fn_two() -> str:
"""
Second testing function
Returns:
str: Description of return value
"""
pass
def test_raise_for_status_success() -> None:
response = Mock(spec=httpx.Response)
response.status_code = 200
result = raise_for_status(response)
assert result == response
def test_raise_for_status_failure_with_text() -> None:
response = Mock(spec=httpx.Response)
response.status_code = 404
response.text = "Not Found: John Cena"
try:
raise_for_status(response)
except httpx.HTTPStatusError as e:
assert e.response == response
assert str(e) == "404 Not Found: John Cena"
assert e.request is None
def test_raise_for_status_failure_without_text() -> None:
response = Mock(spec=httpx.Response)
response.status_code = 500
response.text = ""
try:
raise_for_status(response)
except httpx.HTTPStatusError as e:
assert e.response == response
assert str(e) == "500 Internal Server Error"
assert e.request is None
def test_messages_to_openai_spec() -> None:
messages = [
Message(role="assistant", content=[Text("Hello!")]),
Message(role="user", content=[Text("How are you?")]),
Message(
role="assistant",
content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})],
),
Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]),
]
spec = messages_to_openai_spec(messages)
assert spec == [
{"role": "assistant", "content": "Hello!"},
{"role": "user", "content": "How are you?"},
{
"role": "assistant",
"tool_calls": [
{
"id": 1,
"type": "function",
"function": {
"name": "tool1",
"arguments": '{"param1": "value1"}',
},
}
],
},
{
"role": "tool",
"content": "Result",
"tool_call_id": 1,
},
]
def test_tools_to_openai_spec() -> None:
tools = (Tool.from_function(example_fn), Tool.from_function(example_fn_two))
assert len(tools_to_openai_spec(tools)) == 2
def test_tools_to_openai_spec_duplicate() -> None:
tools = (Tool.from_function(example_fn), Tool.from_function(example_fn))
with pytest.raises(ValueError):
tools_to_openai_spec(tools)
def test_tools_to_openai_spec_single() -> None:
tools = Tool.from_function(example_fn)
expected_spec = [
{
"type": "function",
"function": {
"name": "example_fn",
"description": "Testing function.",
"parameters": {
"type": "object",
"properties": {
"param": {
"type": "string",
"description": "Description of param1",
}
},
"required": ["param"],
},
},
},
]
result = tools_to_openai_spec((tools,))
assert result == expected_spec
def test_tools_to_openai_spec_empty() -> None:
tools = ()
expected_spec = []
assert tools_to_openai_spec(tools) == expected_spec
def test_openai_response_to_message_text() -> None:
response = {
"choices": [
{
"role": "assistant",
"message": {"content": "Hello from John Cena!"},
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 25,
"total_tokens": 35,
},
}
message = openai_response_to_message(response)
actual = asdict(message)
expect = asdict(
Message(
role="assistant",
content=[Text("Hello from John Cena!")],
)
)
actual.pop("id")
expect.pop("id")
assert actual == expect
def test_openai_response_to_message_valid_tooluse() -> None:
response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE)
message = openai_response_to_message(response)
actual = asdict(message)
expect = asdict(
Message(
role="assistant",
content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})],
)
)
actual.pop("id")
actual["content"][0].pop("id")
expect.pop("id")
expect["content"][0].pop("id")
assert actual == expect
def test_openai_response_to_message_invalid_func_name() -> None:
response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE)
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] = "invalid fn"
message = openai_response_to_message(response)
assert message.content[0].name == "invalid fn"
assert json.loads(message.content[0].parameters) == {"param": "value"}
assert message.content[0].is_error
assert message.content[0].error_message.startswith("The provided function name")
@patch("json.loads", side_effect=json.JSONDecodeError("error", "doc", 0))
def test_openai_response_to_message_json_decode_error(mock_json) -> None:
response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE)
message = openai_response_to_message(response)
assert message.content[0].name == "example_fn"
assert message.content[0].is_error
assert message.content[0].error_message.startswith("Could not interpret tool use")

View File

@@ -0,0 +1,763 @@
from typing import List, Tuple
import pytest
from exchange.checkpoint import Checkpoint, CheckpointData
from exchange.content import Text, ToolResult, ToolUse
from exchange.exchange import Exchange
from exchange.message import Message
from exchange.moderators import PassiveModerator
from exchange.providers import Provider, Usage
from exchange.tool import Tool
def dummy_tool() -> str:
"""An example tool"""
return "dummy response"
too_long_output = "x" * (2**20 + 1)
too_long_token_output = "word " * 128000
def no_overlapping_checkpoints(exchange: Exchange) -> bool:
"""Assert that there are no overlapping checkpoints in the exchange."""
for i, checkpoint in enumerate(exchange.checkpoint_data.checkpoints):
for other_checkpoint in exchange.checkpoint_data.checkpoints[i + 1 :]:
if not checkpoint.end_index < other_checkpoint.start_index:
return False
return True
def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]:
return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints]
class MockProvider(Provider):
def __init__(self, sequence: List[Message], usage_dicts: List[dict]):
# We'll use init to provide a preplanned reply sequence
self.sequence = sequence
self.call_count = 0
self.usage_dicts = usage_dicts
@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("input_tokens")
output_tokens = usage.get("output_tokens")
total_tokens = usage.get("total_tokens")
if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message:
output = self.sequence[self.call_count]
usage = self.get_usage(self.usage_dicts[self.call_count])
self.call_count += 1
return (output, usage)
def test_reply_with_unsupported_tool():
ex = Exchange(
provider=MockProvider(
sequence=[
Message(
role="assistant",
content=[ToolUse(id="1", name="unsupported_tool", parameters={})],
),
Message(
role="assistant",
content=[Text(text="Here is the completion after tool call")],
),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test")]))
ex.reply()
content = ex.messages[-2].content[0]
assert isinstance(content, ToolResult) and content.is_error and "no tool exists" in content.output.lower()
def test_invalid_tool_parameters():
"""Test handling of invalid tool parameters response"""
ex = Exchange(
provider=MockProvider(
sequence=[
Message(
role="assistant",
content=[ToolUse(id="1", name="dummy_tool", parameters="invalid json")],
),
Message(
role="assistant",
content=[Text(text="Here is the completion after tool call")],
),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test invalid parameters")]))
ex.reply()
content = ex.messages[-2].content[0]
assert isinstance(content, ToolResult) and content.is_error and "invalid json" in content.output.lower()
def test_max_tool_use_when_limit_reached():
"""Test the max_tool_use parameter in the reply method."""
ex = Exchange(
provider=MockProvider(
sequence=[
Message(
role="assistant",
content=[ToolUse(id="1", name="dummy_tool", parameters={})],
),
Message(
role="assistant",
content=[ToolUse(id="2", name="dummy_tool", parameters={})],
),
Message(
role="assistant",
content=[ToolUse(id="3", name="dummy_tool", parameters={})],
),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test max tool use")]))
response = ex.reply(max_tool_use=3)
assert ex.provider.call_count == 3
assert "reached the limit" in response.content[0].text.lower()
assert isinstance(ex.messages[-2].content[0], ToolResult) and ex.messages[-2].content[0].tool_use_id == "3"
assert ex.messages[-1].role == "assistant"
def test_tool_output_too_long_character_error():
"""Test tool handling when output exceeds character limit."""
def long_output_tool_char() -> str:
return too_long_output
ex = Exchange(
provider=MockProvider(
sequence=[
Message(
role="assistant",
content=[ToolUse(id="1", name="long_output_tool_char", parameters={})],
),
Message(
role="assistant",
content=[Text(text="Here is the completion after tool call")],
),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(long_output_tool_char)],
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test long output char")]))
ex.reply()
content = ex.messages[-2].content[0]
assert (
isinstance(content, ToolResult)
and content.is_error
and "output that was too long to handle" in content.output.lower()
)
def test_tool_output_too_long_token_error():
"""Test tool handling when output exceeds token limit."""
def long_output_tool_token() -> str:
return too_long_token_output
ex = Exchange(
provider=MockProvider(
sequence=[
Message(
role="assistant",
content=[ToolUse(id="1", name="long_output_tool_token", parameters={})],
),
Message(
role="assistant",
content=[Text(text="Here is the completion after tool call")],
),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 12, "output_tokens": 23}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(long_output_tool_token)],
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test long output token")]))
ex.reply()
content = ex.messages[-2].content[0]
assert (
isinstance(content, ToolResult)
and content.is_error
and "output that was too long to handle" in content.output.lower()
)
@pytest.fixture(scope="function")
def normal_exchange() -> Exchange:
ex = Exchange(
provider=MockProvider(
sequence=[
Message(role="assistant", content=[Text(text="Message 1")]),
Message(role="assistant", content=[Text(text="Message 2")]),
Message(role="assistant", content=[Text(text="Message 3")]),
Message(role="assistant", content=[Text(text="Message 4")]),
Message(role="assistant", content=[Text(text="Message 5")]),
],
usage_dicts=[
{"usage": {"total_tokens": 10, "input_tokens": 5, "output_tokens": 5}},
{"usage": {"total_tokens": 28, "input_tokens": 10, "output_tokens": 18}},
{"usage": {"total_tokens": 33, "input_tokens": 28, "output_tokens": 5}},
{"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}},
{"usage": {"total_tokens": 50, "input_tokens": 40, "output_tokens": 10}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
checkpoint_data=CheckpointData(),
)
return ex
@pytest.fixture(scope="function")
def resumed_exchange() -> Exchange:
messages = [
Message(role="user", content=[Text(text="User message 1")]),
Message(role="assistant", content=[Text(text="Assistant Message 1")]),
Message(role="user", content=[Text(text="User message 2")]),
Message(role="assistant", content=[Text(text="Assistant Message 2")]),
Message(role="user", content=[Text(text="User message 3")]),
Message(role="assistant", content=[Text(text="Assistant Message 3")]),
]
provider = MockProvider(
sequence=[
Message(role="assistant", content=[Text(text="Assistant Message 4")]),
],
usage_dicts=[
{"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}},
],
)
ex = Exchange(
provider=provider,
messages=messages,
tools=[],
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
checkpoint_data=CheckpointData(),
moderator=PassiveModerator(),
)
return ex
def test_checkpoints_on_exchange(normal_exchange):
"""Test checkpoints on an exchange."""
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
# Check if checkpoints are created correctly
checkpoints = ex.checkpoint_data.checkpoints
assert len(checkpoints) == 6
for i in range(len(ex.messages)):
# asserting that each message has a corresponding checkpoint
assert checkpoints[i].start_index == i
assert checkpoints[i].end_index == i
# Check if the messages are ordered correctly
assert [msg.content[0].text for msg in ex.messages] == [
"User message",
"Message 1",
"User message",
"Message 2",
"User message",
"Message 3",
]
assert no_overlapping_checkpoints(ex)
def test_checkpoints_on_resumed_exchange(resumed_exchange) -> None:
ex = resumed_exchange
ex.pop_last_message()
ex.reply()
checkpoints = ex.checkpoint_data.checkpoints
assert len(checkpoints) == 2
assert len(ex.messages) == 6
assert checkpoints[0].token_count == 32
assert checkpoints[0].start_index == 0
assert checkpoints[0].end_index == 4
assert checkpoints[1].token_count == 8
assert checkpoints[1].start_index == 5
assert checkpoints[1].end_index == 5
assert no_overlapping_checkpoints(ex)
def test_pop_last_checkpoint_on_resumed_exchange(resumed_exchange) -> None:
ex = resumed_exchange
ex.add(Message(role="user", content=[Text(text="Assistant Message 4")]))
ex.reply()
ex.pop_last_checkpoint()
assert len(ex.messages) == 7
assert len(ex.checkpoint_data.checkpoints) == 1
ex.pop_last_checkpoint()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert no_overlapping_checkpoints(ex)
def test_pop_last_checkpoint_on_normal_exchange(normal_exchange) -> None:
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
ex.pop_last_checkpoint()
ex.pop_last_checkpoint()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
assert no_overlapping_checkpoints(ex)
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.pop_last_checkpoint()
assert len(ex.messages) == 1
assert len(ex.checkpoint_data.checkpoints) == 1
ex.reply()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
assert no_overlapping_checkpoints(ex)
def test_pop_first_message_no_messages():
ex = Exchange(
provider=MockProvider(sequence=[], usage_dicts=[]),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
moderator=PassiveModerator(),
)
with pytest.raises(ValueError) as e:
ex.pop_first_message()
assert str(e.value) == "There are no messages to pop"
def test_pop_first_message_checkpoint_with_many_messages(resumed_exchange):
ex = resumed_exchange
ex.pop_last_message()
ex.reply()
assert len(ex.messages) == 6
assert len(ex.checkpoint_data.checkpoints) == 2
assert ex.checkpoint_data.checkpoints[0].start_index == 0
assert ex.checkpoint_data.checkpoints[0].end_index == 4
assert ex.checkpoint_data.checkpoints[1].start_index == 5
assert ex.checkpoint_data.checkpoints[1].end_index == 5
assert ex.checkpoint_data.message_index_offset == 0
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 5
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 5
assert ex.checkpoint_data.checkpoints[0].end_index == 5
assert ex.checkpoint_data.message_index_offset == 1
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 4
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 5
assert ex.checkpoint_data.checkpoints[0].end_index == 5
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 3
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 5
assert ex.checkpoint_data.checkpoints[0].end_index == 5
assert ex.checkpoint_data.message_index_offset == 3
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 5
assert ex.checkpoint_data.checkpoints[0].end_index == 5
assert ex.checkpoint_data.message_index_offset == 4
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 1
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 5
assert ex.checkpoint_data.checkpoints[0].end_index == 5
assert ex.checkpoint_data.message_index_offset == 5
assert no_overlapping_checkpoints(ex)
ex.pop_first_message()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert ex.checkpoint_data.message_index_offset == 0
assert no_overlapping_checkpoints(ex)
with pytest.raises(ValueError) as e:
ex.pop_first_message()
assert str(e.value) == "There are no messages to pop"
def test_varied_message_manipulation(normal_exchange):
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message 1")]))
ex.reply()
ex.pop_first_message()
ex.add(Message(role="user", content=[Text(text="User message 2")]))
ex.reply()
assert len(ex.messages) == 3
assert len(ex.checkpoint_data.checkpoints) == 3
assert ex.checkpoint_data.message_index_offset == 1
# Checkpoints now cover the (start, end) pairs (1, 1), (2, 2), (3, 3);
# the actual index into ex.messages is the checkpoint index minus the offset.
assert no_overlapping_checkpoints(ex)
for i in range(3):
assert ex.checkpoint_data.checkpoints[i].start_index == i + 1
assert ex.checkpoint_data.checkpoints[i].end_index == i + 1
ex.pop_last_message()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
assert ex.checkpoint_data.message_index_offset == 1
assert no_overlapping_checkpoints(ex)
for i in range(2):
assert ex.checkpoint_data.checkpoints[i].start_index == i + 1
assert ex.checkpoint_data.checkpoints[i].end_index == i + 1
ex.add(Message(role="assistant", content=[Text(text="Assistant message")]))
ex.add(Message(role="user", content=[Text(text="User message 3")]))
ex.reply()
assert len(ex.messages) == 5
assert len(ex.checkpoint_data.checkpoints) == 4
assert ex.checkpoint_data.message_index_offset == 1
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4), (5, 5)]
ex.pop_last_checkpoint()
assert len(ex.messages) == 4
assert len(ex.checkpoint_data.checkpoints) == 3
assert ex.checkpoint_data.message_index_offset == 1
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4)]
ex.pop_first_message()
assert len(ex.messages) == 3
assert len(ex.checkpoint_data.checkpoints) == 2
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 4)]
ex.pop_last_message()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)]
ex.pop_last_message()
assert len(ex.messages) == 1
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)]
ex.add(Message(role="assistant", content=[Text(text="Assistant message")]))
ex.add(Message(role="user", content=[Text(text="User message 5")]))
ex.pop_last_checkpoint()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
ex.add(Message(role="user", content=[Text(text="User message 6")]))
ex.reply()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 3)]
ex.pop_last_message()
assert len(ex.messages) == 1
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.message_index_offset == 2
assert no_overlapping_checkpoints(ex)
assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)]
ex.pop_first_message()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert ex.checkpoint_data.message_index_offset == 0
ex.add(Message(role="user", content=[Text(text="User message 7")]))
ex.pop_last_message()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert ex.checkpoint_data.message_index_offset == 0
def test_pop_last_message_when_no_checkpoints_but_messages_present(normal_exchange):
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.pop_last_message()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert ex.checkpoint_data.message_index_offset == 0
def test_pop_first_message_when_no_checkpoints_but_message_present(normal_exchange):
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message")]))
with pytest.raises(ValueError) as e:
ex.pop_first_message()
assert str(e.value) == "There must be at least one checkpoint to pop the first message"
def test_pop_first_checkpoint_size_n(resumed_exchange):
ex = resumed_exchange
ex.pop_last_message() # needed because the last message is an assistant message
ex.reply()
ex.pop_first_checkpoint()
assert ex.checkpoint_data.message_index_offset == 5
assert len(ex.checkpoint_data.checkpoints) == 1
assert len(ex.messages) == 1
ex.pop_first_checkpoint()
assert ex.checkpoint_data.message_index_offset == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert len(ex.messages) == 0
def test_pop_first_checkpoint_size_1(normal_exchange):
ex = normal_exchange
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
ex.pop_first_checkpoint()
assert ex.checkpoint_data.message_index_offset == 1
assert len(ex.checkpoint_data.checkpoints) == 1
assert len(ex.messages) == 1
ex.pop_first_checkpoint()
assert ex.checkpoint_data.message_index_offset == 0
assert len(ex.checkpoint_data.checkpoints) == 0
assert len(ex.messages) == 0
def test_pop_first_checkpoint_no_checkpoints(normal_exchange):
ex = normal_exchange
with pytest.raises(ValueError) as e:
ex.pop_first_checkpoint()
assert str(e.value) == "There are no checkpoints to pop"
def test_prepend_checkpointed_message_empty_exchange(normal_exchange):
ex = normal_exchange
ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10)
assert ex.checkpoint_data.message_index_offset == 0
assert len(ex.checkpoint_data.checkpoints) == 1
assert ex.checkpoint_data.checkpoints[0].start_index == 0
assert ex.checkpoint_data.checkpoints[0].end_index == 0
ex.add(Message(role="user", content=[Text(text="User message")]))
ex.reply()
assert ex.checkpoint_data.message_index_offset == 0
assert len(ex.checkpoint_data.checkpoints) == 3
assert len(ex.messages) == 3
assert no_overlapping_checkpoints(ex)
ex.pop_first_checkpoint()
assert ex.checkpoint_data.message_index_offset == 1
assert len(ex.checkpoint_data.checkpoints) == 2
assert len(ex.messages) == 2
assert no_overlapping_checkpoints(ex)
ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10)
assert ex.checkpoint_data.message_index_offset == 0
assert len(ex.checkpoint_data.checkpoints) == 3
assert len(ex.messages) == 3
assert no_overlapping_checkpoints(ex)
def test_generate_successful_response_on_first_try(normal_exchange):
ex = normal_exchange
ex.add(Message(role="user", content=[Text("Hello")]))
ex.generate()
def test_rewind_in_normal_exchange(normal_exchange):
ex = normal_exchange
ex.rewind()
assert len(ex.messages) == 0
assert len(ex.checkpoint_data.checkpoints) == 0
ex.add(Message(role="user", content=[Text("Hello")]))
ex.generate()
ex.add(Message(role="user", content=[Text("Hello")]))
# testing if it works with a user text message at the end
ex.rewind()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
ex.add(Message(role="user", content=[Text("Hello")]))
ex.generate()
# testing if it works with a non-user text message at the end
ex.rewind()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
def test_rewind_with_tool_usage():
# simulating a real exchange with tool usage
ex = Exchange(
provider=MockProvider(
sequence=[
Message.assistant("Hello!"),
Message(
role="assistant",
content=[ToolUse(id="1", name="dummy_tool", parameters={})],
),
Message(
role="assistant",
content=[ToolUse(id="2", name="dummy_tool", parameters={})],
),
Message.assistant("Done!"),
],
usage_dicts=[
{"usage": {"input_tokens": 12, "output_tokens": 23}},
{"usage": {"input_tokens": 27, "output_tokens": 44}},
{"usage": {"input_tokens": 50, "output_tokens": 56}},
{"usage": {"input_tokens": 60, "output_tokens": 76}},
],
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test")]))
ex.reply()
ex.add(Message(role="user", content=[Text(text="kick it off!")]))
ex.reply()
# removing the last message to simulate not getting a response
ex.pop_last_message()
# calling rewind to last user message
ex.rewind()
assert len(ex.messages) == 2
assert len(ex.checkpoint_data.checkpoints) == 2
assert no_overlapping_checkpoints(ex)
assert ex.messages[0].content[0].text == "test"
assert type(ex.messages[1].content[0]) is Text
assert ex.messages[1].role == "assistant"

View File

@@ -0,0 +1,33 @@
from unittest.mock import MagicMock
from exchange.exchange import Exchange
from exchange.message import Message
from exchange.moderators.passive import PassiveModerator
from exchange.providers.base import Provider
from exchange.tool import Tool
from exchange.token_usage_collector import _TokenUsageCollector
MODEL_NAME = "test-model"
def create_exchange(mock_provider, dummy_tool):
return Exchange(
provider=mock_provider,
model=MODEL_NAME,
system="test-system",
tools=(Tool.from_function(dummy_tool),),
messages=[],
moderator=PassiveModerator(),
)
def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch):
mock_provider = MagicMock(spec=Provider)
mock_usage_collector = MagicMock(spec=_TokenUsageCollector)
usage = usage_factory()
mock_provider.complete.return_value = (Message.assistant("msg"), usage)
exchange = create_exchange(mock_provider, dummy_tool)
monkeypatch.setattr("exchange.exchange._token_usage_collector", mock_usage_collector)
exchange.generate()
mock_usage_collector.collect.assert_called_once_with(MODEL_NAME, usage)

View File

@@ -0,0 +1,48 @@
import pytest
from attr.exceptions import FrozenInstanceError
from exchange.content import Text
from exchange.exchange import Exchange
from exchange.moderators import PassiveModerator
from exchange.message import Message
from exchange.providers import Provider, Usage
from exchange.tool import Tool
class MockProvider(Provider):
def complete(self, model, system, messages, tools=None):
return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict(
{"total_tokens": 35}
)
def test_exchange_immutable(dummy_tool):
# Create an instance of Exchange
provider = MockProvider()
# tools are held as a tuple on the frozen exchange, so they cannot be modified in place
exchange = Exchange(
provider=provider,
model="test-model",
system="test-system",
tools=(Tool.from_function(dummy_tool),),
messages=[Message(role="user", content=[Text(text="Hello!")])],
moderator=PassiveModerator(),
)
# Try to directly modify a field (should raise an error)
with pytest.raises(FrozenInstanceError):
exchange.system = ""
with pytest.raises(AttributeError):
exchange.tools.append("anything")
# Replace method should return a new instance with deepcopy of messages
new_exchange = exchange.replace(system="changed")
assert new_exchange.system == "changed"
assert len(exchange.messages) == 1
assert len(new_exchange.messages) == 1
# Ensure that the messages are deep copied
new_exchange.messages[0].content[0].text = "Changed!"
assert exchange.messages[0].content[0].text == "Hello!"
assert new_exchange.messages[0].content[0].text == "Changed!"

Binary file not shown.


View File

@@ -0,0 +1,89 @@
import os
import pytest
from exchange.exchange import Exchange
from exchange.message import Message
from exchange.moderators import ContextTruncate
from exchange.providers import get_provider
from exchange.providers.ollama import OLLAMA_MODEL
from exchange.tool import Tool
from tests.conftest import read_file
too_long_chars = "x" * (2**20 + 1)
cases = [
# Set seed and temperature for more determinism, to avoid flakes
(get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)),
(get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini"), dict()),
(get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()),
(get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
(get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
(get_provider("google"), "gemini-1.5-flash", dict()),
]
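# Each case constructs its provider via from_env() below, so the matching credentials and
# configuration must be present in the environment when these integration tests run.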
@pytest.mark.integration
@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_simple(provider, model, kwargs):
provider = provider.from_env()
ex = Exchange(
provider=provider,
model=model,
moderator=ContextTruncate(model),
system="You are a helpful assistant.",
generation_args=kwargs,
)
ex.add(Message.user("Who is the most famous wizard from the lord of the rings"))
response = ex.reply()
# This could be flaky in principle, but we haven't seen it fail in practice so far
assert "gandalf" in response.text.lower()
@pytest.mark.integration
@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_tools(provider, model, kwargs, tmp_path):
provider = provider.from_env()
ex = Exchange(
provider=provider,
model=model,
moderator=ContextTruncate(model),
system="You are a helpful assistant. Expect to need to read a file using read_file.",
tools=(Tool.from_function(read_file),),
generation_args=kwargs,
)
ex.add(Message.user("What are the contents of this file? test.txt"))
response = ex.reply()
assert "hello exchange" in response.text.lower()
@pytest.mark.integration
@pytest.mark.parametrize("provider,model,kwargs", cases)
def test_tool_use_output_chars(provider, model, kwargs):
provider = provider.from_env()
def get_password() -> str:
"""Return the password for authentication"""
return too_long_chars
ex = Exchange(
provider=provider,
model=model,
moderator=ContextTruncate(model),
system="You are a helpful assistant. Expect to need to authenticate using get_password.",
tools=(Tool.from_function(get_password),),
generation_args=kwargs,
)
ex.add(Message.user("Can you authenticate this session by responding with the password"))
ex.reply()
# Without our error handling, this would raise
# string too long. Expected a string with maximum length 1048576, but got a string with length ...

View File

@@ -0,0 +1,44 @@
import os
import pytest
from exchange.content import ToolResult, ToolUse
from exchange.exchange import Exchange
from exchange.message import Message
from exchange.moderators import ContextTruncate
from exchange.providers import get_provider
cases = [
(get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")),
]
@pytest.mark.integration # skipped in CI/CD
@pytest.mark.parametrize("provider,model", cases)
def test_simple(provider, model):
provider = provider.from_env()
ex = Exchange(
provider=provider,
model=model,
moderator=ContextTruncate(model),
system="You are a helpful assistant.",
)
ex.add(Message.user("What does the first entry in the menu say?"))
ex.add(
Message(
role="assistant",
content=[ToolUse(id="xyz", name="screenshot", parameters={})],
)
)
ex.add(
Message(
role="user",
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
)
)
response = ex.reply()
# This could be flaky in principle, but we haven't seen it fail in practice so far
assert "ask goose" in response.text.lower()

View File

@@ -0,0 +1,96 @@
import subprocess
from pathlib import Path
import pytest
from exchange.message import Message
from exchange.content import Text, ToolUse, ToolResult
def test_user_message():
user_message = Message.user("abcd")
assert user_message.role == "user"
assert user_message.text == "abcd"
def test_assistant_message():
assistant_message = Message.assistant("abcd")
assert assistant_message.role == "assistant"
assert assistant_message.text == "abcd"
def test_message_tool_use():
from exchange.content import ToolUse
tu1 = ToolUse(id="1", name="tool", parameters={})
tu2 = ToolUse(id="2", name="tool", parameters={})
message = Message(role="assistant", content=[tu1, tu2])
assert len(message.tool_use) == 2
assert message.tool_use[0].name == "tool"
def test_message_tool_result():
from exchange.content import ToolResult
tr1 = ToolResult(tool_use_id="1", output="result")
tr2 = ToolResult(tool_use_id="2", output="result")
message = Message(role="user", content=[tr1, tr2])
assert len(message.tool_result) == 2
assert message.tool_result[0].output == "result"
def test_message_load(tmpdir):
# To emulate the expected relative lookup, we need to create a mock code dir
# and run the load in a subprocess
test_dir = Path(tmpdir)
# Create a temporary Jinja template file in the test_dir
template_content = "hello {{ name }} {% include 'relative.jinja' %}"
template_path = test_dir / "template.jinja"
template_path.write_text(template_content)
relative_content = "and {{ name2 }}"
relative_path = test_dir / "relative.jinja"
relative_path.write_text(relative_content)
# Create a temporary Python file in the sub_dir that calls the load method with a relative path
python_file_content = """
from exchange.message import Message
def test_function():
message = Message.load('template.jinja', name="a", name2="b")
assert message.text == "hello a and b"
assert message.role == "user"
test_function()
"""
python_file_path = test_dir / "test_script.py"
python_file_path.write_text(python_file_content)
# Execute the temporary Python file to test the relative lookup functionality
result = subprocess.run(["python3", str(python_file_path)])
assert result.returncode == 0
def test_message_validation():
# Valid user message
message = Message(role="user", content=[Text(text="Hello")])
assert message.text == "Hello"
# Valid assistant message
message = Message(role="assistant", content=[Text(text="Hello")])
assert message.text == "Hello"
# Invalid message: user with tool_use
with pytest.raises(ValueError):
Message(
role="user",
content=[Text(text=""), ToolUse(id="1", name="tool", parameters={})],
)
# Invalid message: assistant with tool_result
with pytest.raises(ValueError):
Message(
role="assistant",
content=[Text(text=""), ToolResult(tool_use_id="1", output="result")],
)

View File

@@ -0,0 +1,227 @@
import pytest
from exchange import Exchange, Message
from exchange.content import ToolResult, ToolUse
from exchange.moderators.passive import PassiveModerator
from exchange.moderators.summarizer import ContextSummarizer
from exchange.providers import Usage
class MockProvider:
def complete(self, model, system, messages, tools):
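# Always reply with the same canned summary text; pick the role so it alternates with
# the last incoming message, and derive token counts from character lengths.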
assistant_message_text = "Summarized content here."
output_tokens = len(assistant_message_text)
total_input_tokens = sum(len(msg.text) for msg in messages)
if not messages or messages[-1].role == "assistant":
message = Message.user(assistant_message_text)
else:
message = Message.assistant(assistant_message_text)
total_tokens = total_input_tokens + output_tokens
usage = Usage(
input_tokens=total_input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
return message, usage
@pytest.fixture
def exchange_instance():
ex = Exchange(
provider=MockProvider(),
model="test-model",
system="test-system",
messages=[
Message.user("Hi, can you help me with my homework?"),
Message.assistant("Of course! What do you need help with?"),
Message.user("I need help with math problems."),
Message.assistant("Sure, I can help with that. Let's get started."),
Message.user("Can you also help with my science homework?"),
Message.assistant("Yes, I can help with science too."),
Message.user("That's great! How about history?"),
Message.assistant("Of course, I can help with history as well."),
Message.user("Thanks! You're very helpful."),
Message.assistant("You're welcome! I'm here to help."),
],
moderator=PassiveModerator(),
)
return ex
@pytest.fixture
def summarizer_instance():
return ContextSummarizer(max_tokens=300)
def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer):
# Pre-checks
assert len(exchange_instance.messages) == 10
exchange_instance.generate()
# the exchange instance has a PassiveModerator, so the messages are neither truncated nor summarized
assert len(exchange_instance.messages) == 11
assert len(exchange_instance.checkpoint_data.checkpoints) == 2
# we now tell the summarizer to summarize the exchange
summarizer_instance.rewrite(exchange_instance)
assert exchange_instance.checkpoint_data.total_token_count <= 200
assert len(exchange_instance.messages) == 2
# Assert that summarized content is the first message
first_message = exchange_instance.messages[0]
assert first_message.role == "user" or first_message.role == "assistant"
assert any("summarized" in content.text.lower() for content in first_message.content)
# Ensure roles alternate in the output
for i in range(1, len(exchange_instance.messages)):
assert (
exchange_instance.messages[i - 1].role != exchange_instance.messages[i].role
), "Messages must alternate between user and assistant"
MESSAGE_SEQUENCE = [
Message.user("Hi, can you help me with my homework?"),
Message.assistant("Of course! What do you need help with?"),
Message.user("I need help with math problems."),
Message.assistant("Sure, I can help with that. Let's get started."),
Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
Message(
role="assistant",
content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})],
),
Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]),
Message(
role="assistant",
content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})],
),
Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]),
Message(
role="assistant",
content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})],
),
Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]),
Message(
role="assistant",
content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})],
),
Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]),
Message(
role="assistant",
content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})],
),
Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]),
Message.assistant("I'm done calculating the answers to your math questions."),
Message.user("Can you also help with my science homework?"),
Message.assistant("Yes, I can help with science too."),
Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
Message(
role="assistant",
content=[ToolUse(id="6", name="speed_of_light", parameters={})],
),
Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]),
Message(
role="assistant",
content=[ToolUse(id="7", name="photon_frequency", parameters={})],
),
Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]),
Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]),
Message(
role="user",
content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")],
),
Message.assistant("I'm done calculating the answers to your science questions."),
Message.user("That's great! How about history?"),
Message.assistant("Of course, I can help with history as well."),
Message.user("Thanks! You're very helpful."),
Message.assistant("You're welcome! I'm here to help."),
]
class AnotherMockProvider:
def __init__(self):
self.sequence = MESSAGE_SEQUENCE
self.current_index = 1
self.summarize_next = False
self.summarized_count = 0
def complete(self, model, system, messages, tools):
system_prompt_tokens = 100
input_token_count = system_prompt_tokens
message = self.sequence[self.current_index]
if self.summarize_next:
text = "Summary message here"
self.summarize_next = False
self.summarized_count += 1
return Message.assistant(text=text), Usage(
# in this case, input tokens can really be whatever
input_tokens=40,
output_tokens=len(text) * 2,
total_tokens=40 + len(text) * 2,
)
if len(messages) > 0 and type(messages[0].content[0]) is ToolResult:
raise ValueError("ToolResult should not be the first message")
if len(messages) == 1 and messages[0].text == "a":
# adding a +1 for the "a"
return Message.assistant("Getting system prompt size"), Usage(
input_tokens=80 + 1, output_tokens=20, total_tokens=system_prompt_tokens + 1
)
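# Rough token accounting: tool content counts as a flat 10 tokens, plain text as two
# tokens per character (e.g. a 5-character message contributes 10 input tokens).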
for i in range(len(messages)):
if type(messages[i].content[0]) in (ToolResult, ToolUse):
input_token_count += 10
else:
input_token_count += len(messages[i].text) * 2
if type(message.content[0]) in (ToolResult, ToolUse):
output_tokens = 10
else:
output_tokens = len(message.text) * 2
total_tokens = input_token_count + output_tokens
if total_tokens > 300:
self.summarize_next = True
usage = Usage(
input_tokens=input_token_count,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
self.current_index += 2
return message, usage
@pytest.fixture
def conversation_exchange_instance():
ex = Exchange(
provider=AnotherMockProvider(),
model="test-model",
system="test-system",
moderator=ContextSummarizer(max_tokens=300),
# TODO: make it work with an offset so we don't have to send off requests basically
# at every generate step
)
return ex
def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange):
i = 0
while i < len(MESSAGE_SEQUENCE):
next_message = MESSAGE_SEQUENCE[i]
conversation_exchange_instance.add(next_message)
message = conversation_exchange_instance.generate()
if message.text != "Summary message here":
i += 2
checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints
assert conversation_exchange_instance.checkpoint_data.total_token_count == 570
assert len(checkpoints) == 10
assert len(conversation_exchange_instance.messages) == 10
assert checkpoints[0].start_index == 20
assert checkpoints[0].end_index == 20
assert checkpoints[-1].start_index == 29
assert checkpoints[-1].end_index == 29
assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20
assert conversation_exchange_instance.provider.summarized_count == 12
assert conversation_exchange_instance.moderator.system_prompt_token_count == 100

View File

@@ -0,0 +1,24 @@
from exchange.token_usage_collector import _TokenUsageCollector
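# usage_factory is a conftest fixture; judging by its use below, it builds a Usage from
# positional (input_tokens, output_tokens, total_tokens) values.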
def test_collect(usage_factory):
usage_collector = _TokenUsageCollector()
usage_collector.collect("model1", usage_factory(100, 1000, 1100))
usage_collector.collect("model1", usage_factory(200, 2000, 2200))
usage_collector.collect("model2", usage_factory(400, 4000, 4400))
usage_collector.collect("model3", usage_factory(500, 5000, 5500))
usage_collector.collect("model3", usage_factory(600, 6000, 6600))
assert usage_collector.get_token_usage_group_by_model() == {
"model1": usage_factory(300, 3000, 3300),
"model2": usage_factory(400, 4000, 4400),
"model3": usage_factory(1100, 11000, 12100),
}
def test_collect_with_non_input_or_output_token(usage_factory):
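# Missing counts are treated as zero when aggregating, so a None input or output
# token count does not break the per-model totals.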
usage_collector = _TokenUsageCollector()
usage_collector.collect("model1", usage_factory(100, None, None))
usage_collector.collect("model1", usage_factory(None, 2000, None))
assert usage_collector.get_token_usage_group_by_model() == {
"model1": usage_factory(100, 2000, 0),
}

View File

@@ -0,0 +1,161 @@
import attrs
from exchange.tool import Tool
def get_current_weather(location: str) -> None:
"""Get the current weather in a given location
Args:
location (str): The city and state, e.g. San Francisco, CA
"""
pass
def test_load():
tool = Tool.from_function(get_current_weather)
expected = {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
"required": ["location"],
},
"function": get_current_weather,
}
assert attrs.asdict(tool) == expected
def another_function(
param1: int,
param2: str,
param3: bool,
param4: float,
param5: list[int],
param6: dict[str, float],
) -> None:
"""
This is another example function with various types
Args:
param1 (int): Description for param1
param2 (str): Description for param2
param3 (bool): Description for param3
param4 (float): Description for param4
param5 (list[int]): Description for param5
param6 (dict[str, float]): Description for param6
"""
pass
def test_load_types():
tool = Tool.from_function(another_function)
expected_schema = {
"type": "object",
"properties": {
"param1": {"type": "integer", "description": "Description for param1"},
"param2": {"type": "string", "description": "Description for param2"},
"param3": {"type": "boolean", "description": "Description for param3"},
"param4": {"type": "number", "description": "Description for param4"},
"param5": {
"type": "array",
"items": {"type": "integer"},
"description": "Description for param5",
},
"param6": {
"type": "object",
"additionalProperties": {"type": "number"},
"description": "Description for param6",
},
},
"required": ["param1", "param2", "param3", "param4", "param5", "param6"],
}
assert tool.parameters == expected_schema
def numpy_function(param1: int, param2: str) -> None:
"""
This function uses numpy style docstrings
Parameters
----------
param1 : int
Description for param1
param2 : str
Description for param2
"""
pass
def test_load_numpy_style():
tool = Tool.from_function(numpy_function)
expected_schema = {
"type": "object",
"properties": {
"param1": {"type": "integer", "description": "Description for param1"},
"param2": {"type": "string", "description": "Description for param2"},
},
"required": ["param1", "param2"],
}
assert tool.parameters == expected_schema
def sphinx_function(param1: int, param2: str, param3: bool) -> None:
"""
This function uses sphinx style docstrings
:param param1: Description for param1
:type param1: int
:param param2: Description for param2
:type param2: str
:param param3: Description for param3
:type param3: bool
"""
pass
def test_load_sphinx_style():
tool = Tool.from_function(sphinx_function)
expected_schema = {
"type": "object",
"properties": {
"param1": {"type": "integer", "description": "Description for param1"},
"param2": {"type": "string", "description": "Description for param2"},
"param3": {"type": "boolean", "description": "Description for param3"},
},
"required": ["param1", "param2", "param3"],
}
assert tool.parameters == expected_schema
class FunctionLike:
def __init__(self, state: int) -> None:
self.state = state
def __call__(self, param1: int) -> int:
"""Example
Args:
param1 (int): Description for param1
"""
return self.state + param1
def test_load_stateful_class():
tool = Tool.from_function(FunctionLike(1))
expected_schema = {
"type": "object",
"properties": {
"param1": {"type": "integer", "description": "Description for param1"},
},
"required": ["param1"],
}
assert tool.parameters == expected_schema
assert tool.function(2) == 3

View File

@@ -0,0 +1,132 @@
import pytest
from exchange import Exchange
from exchange.content import ToolResult, ToolUse
from exchange.message import Message
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
MAX_TOKENS = 300
SYSTEM_PROMPT_TOKENS = 100
MESSAGE_SEQUENCE = [
Message.user("Hi, can you help me with my homework?"),
Message.assistant("Of course! What do you need help with?"),
Message.user("I need help with math problems."),
Message.assistant("Sure, I can help with that. Let's get started."),
Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"),
Message(
role="assistant",
content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})],
),
Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]),
Message(
role="assistant",
content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})],
),
Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]),
Message(
role="assistant",
content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})],
),
Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]),
Message(
role="assistant",
content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})],
),
Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]),
Message(
role="assistant",
content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})],
),
Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]),
Message.assistant("I'm done calculating the answers to your math questions."),
Message.user("Can you also help with my science homework?"),
Message.assistant("Yes, I can help with science too."),
Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"),
Message(
role="assistant",
content=[ToolUse(id="6", name="speed_of_light", parameters={})],
),
Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]),
Message(
role="assistant",
content=[ToolUse(id="7", name="photon_frequency", parameters={})],
),
Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]),
Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]),
Message(
role="user",
content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")],
),
Message.assistant("I'm done calculating the answers to your science questions."),
Message.user("That's great! How about history?"),
Message.assistant("Of course, I can help with history as well."),
Message.user("Thanks! You're very helpful."),
Message.assistant("You're welcome! I'm here to help."),
]
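# A scripted provider: it replays MESSAGE_SEQUENCE as assistant replies and fabricates token usage
# (a flat 10 tokens per tool message, 2 tokens per character of text, and a fixed size for the
# system-prompt probe) so ContextTruncate has realistic counts to act on.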
class TruncateLinearProvider(Provider):
def __init__(self):
self.sequence = MESSAGE_SEQUENCE
self.current_index = 1
self.summarize_next = False
self.summarized_count = 0
def complete(self, model, system, messages, tools):
input_token_count = SYSTEM_PROMPT_TOKENS
message = self.sequence[self.current_index]
if len(messages) > 0 and type(messages[0].content[0]) is ToolResult:
raise ValueError("ToolResult should not be the first message")
if len(messages) == 1 and messages[0].text == "a":
# adding a +1 for the "a"
return Message.assistant("Getting system prompt size"), Usage(
input_tokens=80 + 1, output_tokens=20, total_tokens=SYSTEM_PROMPT_TOKENS + 1
)
for i in range(len(messages)):
if type(messages[i].content[0]) in (ToolResult, ToolUse):
input_token_count += 10
else:
input_token_count += len(messages[i].text) * 2
if type(message.content[0]) in (ToolResult, ToolUse):
output_tokens = 10
else:
output_tokens = len(message.text) * 2
total_tokens = input_token_count + output_tokens
usage = Usage(
input_tokens=input_token_count,
output_tokens=output_tokens,
total_tokens=total_tokens,
)
self.current_index += 2
return message, usage
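# The exchange below gets a 500-token budget, far less than the full conversation would need, so truncation has to kick in.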
@pytest.fixture
def conversation_exchange_instance():
ex = Exchange(
provider=TruncateLinearProvider(),
model="test-model",
system="test-system",
moderator=ContextTruncate(max_tokens=500),
)
return ex
def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange):
i = 0
while i < len(MESSAGE_SEQUENCE):
next_message = MESSAGE_SEQUENCE[i]
conversation_exchange_instance.add(next_message)
message = conversation_exchange_instance.generate()
if message.text != "Summary message here":
i += 2
# ensure the total token count is not anything exorbitant
assert conversation_exchange_instance.checkpoint_data.total_token_count < 700
assert conversation_exchange_instance.moderator.system_prompt_token_count == 100

View File

@@ -0,0 +1,125 @@
import pytest
from exchange import utils
from unittest.mock import patch
from exchange.message import Message
from exchange.content import Text, ToolResult
from exchange.providers.utils import messages_to_openai_spec, encode_image
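# Covers the shared utilities: image encoding, object id generation, docstring parsing, JSON schema generation, plugin loading, and conversion of messages to the OpenAI format.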
def test_encode_image():
image_path = "tests/test_image.png"
encoded_image = encode_image(image_path)
# Every base64-encoded PNG starts with this prefix; it is the encoded PNG file signature.
expected_start = "iVBORw0KGgo"
assert encoded_image.startswith(expected_start)
def test_create_object_id() -> None:
prefix = "test"
object_id = utils.create_object_id(prefix)
assert object_id.startswith(prefix + "_")
assert len(object_id) == len(prefix) + 1 + 24 # prefix + _ + 24 chars
def test_compact() -> None:
content = "This is \n\n a test"
compacted = utils.compact(content)
assert compacted == "This is a test"
def test_parse_docstring() -> None:
def dummy_func(a, b, c):
"""
This function does something.
Args:
a (int): The first parameter.
b (str): The second parameter.
c (list): The third parameter.
"""
pass
description, parameters = utils.parse_docstring(dummy_func)
assert description == "This function does something."
assert parameters == [
{"name": "a", "annotation": "int", "description": "The first parameter."},
{"name": "b", "annotation": "str", "description": "The second parameter."},
{"name": "c", "annotation": "list", "description": "The third parameter."},
]
def test_parse_docstring_no_description() -> None:
def dummy_func(a, b, c):
"""
Args:
a (int): The first parameter.
b (str): The second parameter.
c (list): The third parameter.
"""
pass
with pytest.raises(ValueError) as e:
utils.parse_docstring(dummy_func)
assert "Attempted to load from a function" in str(e.value)
def test_json_schema() -> None:
def dummy_func(a: int, b: str, c: list) -> None:
pass
schema = utils.json_schema(dummy_func)
assert schema == {
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "string"},
"c": {"type": "string"},
},
"required": ["a", "b", "c"],
}
def test_load_plugins() -> None:
class DummyEntryPoint:
def __init__(self, name, plugin):
self.name = name
self.plugin = plugin
def load(self):
return self.plugin
with patch("exchange.utils.entry_points") as entry_points_mock:
entry_points_mock.return_value = [
DummyEntryPoint("plugin1", object()),
DummyEntryPoint("plugin2", object()),
]
plugins = utils.load_plugins("dummy_group")
assert "plugin1" in plugins
assert "plugin2" in plugins
assert len(plugins) == 2
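# A tool result that references an image file should be rewritten into an upload notice plus a follow-up user message carrying the image_url payload.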
def test_messages_to_openai_spec():
# Use provided test image
png_path = "tests/test_image.png"
# Create a list of messages as input
messages = [
Message(role="user", content=[Text(text="Hello, Assistant!")]),
Message(role="assistant", content=[Text(text="Here is a text with tool usage")]),
Message(
role="tool",
content=[ToolResult(tool_use_id="1", output=f'"image:{png_path}')],
),
]
# Call the function
output = messages_to_openai_spec(messages)
assert "This tool result included an image that is uploaded in the next message." in str(output)
assert "{'role': 'user', 'content': [{'type': 'image_url'" in str(output)

View File

@@ -5,10 +5,10 @@ version = "0.9.3"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ai-exchange",
"attrs>=23.2.0",
"rich>=13.7.1",
"ruamel-yaml>=0.18.6",
"ai-exchange>=0.9.3",
"click>=8.1.7",
"prompt-toolkit>=3.0.47",
]
@@ -53,7 +53,6 @@ dev-dependencies = [
"mkdocs-gen-files>=0.5.0",
"mkdocs-git-authors-plugin>=0.9.0",
"mkdocs-git-committers-plugin>=0.2.3",
"mkdocs-git-revision-date-localized-plugin",
"mkdocs-git-revision-date-localized-plugin>=1.2.9",
"mkdocs-glightbox>=0.4.0",
"mkdocs-include-markdown-plugin>=6.2.2",
@@ -66,3 +65,9 @@ dev-dependencies = [
"pytest-mock>=3.14.0",
"pytest>=8.3.2"
]
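# Resolve ai-exchange from the local workspace member under packages/ rather than from PyPI.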
[tool.uv.sources]
ai-exchange = { workspace = true }
[tool.uv.workspace]
members = ["packages/*"]