mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from exchange import Message, Text
|
|
from exchange.providers.databricks import DatabricksProvider
|
|
|
|
|
|
@pytest.fixture
def databricks_provider():
    """Return a ``DatabricksProvider`` built from a faked environment.

    ``DATABRICKS_HOST`` and ``DATABRICKS_TOKEN`` are injected into
    ``os.environ`` only for the duration of ``DatabricksProvider.from_env()``.
    The original environment is restored as soon as the provider exists.
    """
    fake_env = {"DATABRICKS_HOST": "http://test-host", "DATABRICKS_TOKEN": "test_token"}
    with patch.dict(os.environ, fake_env):
        return DatabricksProvider.from_env()
|
|
|
|
|
|
@patch("httpx.Client.post")
@patch("time.sleep", return_value=None)
@patch("logging.warning")
@patch("logging.error")
def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, databricks_provider):
    """Check that one chat completion is sent to, and parsed from, the serving endpoint.

    ``httpx.Client.post`` is replaced by a mock whose ``.json()`` returns a
    canned chat-completion payload. The test then checks three things:
    - the parsed reply text and the token usage;
    - that exactly one POST was made;
    - the endpoint path and JSON body of that POST.

    ``time.sleep``, ``logging.warning`` and ``logging.error`` are patched as
    well but are not asserted on here.
    """
    system_prompt = "You are a helpful assistant."

    # Canned response body returned by the mocked HTTP call.
    endpoint_payload = {
        "choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35},
    }
    mock_post.return_value.json.return_value = endpoint_payload

    message, usage = databricks_provider.complete(
        model="my-databricks-model",
        system=system_prompt,
        messages=[Message.user("Hello")],
        tools=(),
    )

    assert message.content == [Text(text="Hello!")]
    assert usage.total_tokens == 35
    assert mock_post.call_count == 1

    # The system prompt is expected to be sent as the first chat message.
    expected_body = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "Hello"},
        ]
    }
    mock_post.assert_called_once_with(
        "serving-endpoints/my-databricks-model/invocations",
        json=expected_body,
    )
|