Files
goose/packages/exchange/tests/providers/test_databricks.py
2024-10-02 11:05:43 -07:00

50 lines
1.5 KiB
Python

import os
from unittest.mock import patch
import pytest
from exchange import Message, Text
from exchange.providers.databricks import DatabricksProvider
@pytest.fixture
@patch.dict(
os.environ,
{"DATABRICKS_HOST": "http://test-host", "DATABRICKS_TOKEN": "test_token"},
)
def databricks_provider():
return DatabricksProvider.from_env()
@patch("httpx.Client.post")
@patch("time.sleep", return_value=None)
@patch("logging.warning")
@patch("logging.error")
def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, databricks_provider):
mock_response = {
"choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35},
}
mock_post.return_value.json.return_value = mock_response
model = "my-databricks-model"
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
tools = ()
reply_message, reply_usage = databricks_provider.complete(
model=model, system=system, messages=messages, tools=tools
)
assert reply_message.content == [Text(text="Hello!")]
assert reply_usage.total_tokens == 35
assert mock_post.call_count == 1
mock_post.assert_called_once_with(
"serving-endpoints/my-databricks-model/invocations",
json={
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": "Hello"},
]
},
)