mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
181 lines
7.3 KiB
Python
181 lines
7.3 KiB
Python
"""
|
|
Tests for completion handler with context functionality.
|
|
"""
|
|
|
|
import pytest
|
|
|
|
from mcp.server.lowlevel import Server
|
|
from mcp.shared.memory import create_connected_server_and_client_session
|
|
from mcp.types import (
|
|
Completion,
|
|
CompletionArgument,
|
|
CompletionContext,
|
|
PromptReference,
|
|
ResourceTemplateReference,
|
|
)
|
|
|
|
|
|
@pytest.mark.anyio
|
|
async def test_completion_handler_receives_context():
|
|
"""Test that the completion handler receives context correctly."""
|
|
server = Server("test-server")
|
|
|
|
# Track what the handler receives
|
|
received_args = {}
|
|
|
|
@server.completion()
|
|
async def handle_completion(
|
|
ref: PromptReference | ResourceTemplateReference,
|
|
argument: CompletionArgument,
|
|
context: CompletionContext | None,
|
|
) -> Completion | None:
|
|
received_args["ref"] = ref
|
|
received_args["argument"] = argument
|
|
received_args["context"] = context
|
|
|
|
# Return test completion
|
|
return Completion(values=["test-completion"], total=1, hasMore=False)
|
|
|
|
async with create_connected_server_and_client_session(server) as client:
|
|
# Test with context
|
|
result = await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"),
|
|
argument={"name": "param", "value": "test"},
|
|
context_arguments={"previous": "value"},
|
|
)
|
|
|
|
# Verify handler received the context
|
|
assert received_args["context"] is not None
|
|
assert received_args["context"].arguments == {"previous": "value"}
|
|
assert result.completion.values == ["test-completion"]
|
|
|
|
|
|
@pytest.mark.anyio
|
|
async def test_completion_backward_compatibility():
|
|
"""Test that completion works without context (backward compatibility)."""
|
|
server = Server("test-server")
|
|
|
|
context_was_none = False
|
|
|
|
@server.completion()
|
|
async def handle_completion(
|
|
ref: PromptReference | ResourceTemplateReference,
|
|
argument: CompletionArgument,
|
|
context: CompletionContext | None,
|
|
) -> Completion | None:
|
|
nonlocal context_was_none
|
|
context_was_none = context is None
|
|
|
|
return Completion(values=["no-context-completion"], total=1, hasMore=False)
|
|
|
|
async with create_connected_server_and_client_session(server) as client:
|
|
# Test without context
|
|
result = await client.complete(
|
|
ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"}
|
|
)
|
|
|
|
# Verify context was None
|
|
assert context_was_none
|
|
assert result.completion.values == ["no-context-completion"]
|
|
|
|
|
|
@pytest.mark.anyio
|
|
async def test_dependent_completion_scenario():
|
|
"""Test a real-world scenario with dependent completions."""
|
|
server = Server("test-server")
|
|
|
|
@server.completion()
|
|
async def handle_completion(
|
|
ref: PromptReference | ResourceTemplateReference,
|
|
argument: CompletionArgument,
|
|
context: CompletionContext | None,
|
|
) -> Completion | None:
|
|
# Simulate database/table completion scenario
|
|
if isinstance(ref, ResourceTemplateReference):
|
|
if ref.uri == "db://{database}/{table}":
|
|
if argument.name == "database":
|
|
# Complete database names
|
|
return Completion(values=["users_db", "products_db", "analytics_db"], total=3, hasMore=False)
|
|
elif argument.name == "table":
|
|
# Complete table names based on selected database
|
|
if context and context.arguments:
|
|
db = context.arguments.get("database")
|
|
if db == "users_db":
|
|
return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False)
|
|
elif db == "products_db":
|
|
return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False)
|
|
|
|
return Completion(values=[], total=0, hasMore=False)
|
|
|
|
async with create_connected_server_and_client_session(server) as client:
|
|
# First, complete database
|
|
db_result = await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
|
argument={"name": "database", "value": ""},
|
|
)
|
|
assert "users_db" in db_result.completion.values
|
|
assert "products_db" in db_result.completion.values
|
|
|
|
# Then complete table with database context
|
|
table_result = await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
|
argument={"name": "table", "value": ""},
|
|
context_arguments={"database": "users_db"},
|
|
)
|
|
assert table_result.completion.values == ["users", "sessions", "permissions"]
|
|
|
|
# Different database gives different tables
|
|
table_result2 = await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
|
argument={"name": "table", "value": ""},
|
|
context_arguments={"database": "products_db"},
|
|
)
|
|
assert table_result2.completion.values == ["products", "categories", "inventory"]
|
|
|
|
|
|
@pytest.mark.anyio
|
|
async def test_completion_error_on_missing_context():
|
|
"""Test that server can raise error when required context is missing."""
|
|
server = Server("test-server")
|
|
|
|
@server.completion()
|
|
async def handle_completion(
|
|
ref: PromptReference | ResourceTemplateReference,
|
|
argument: CompletionArgument,
|
|
context: CompletionContext | None,
|
|
) -> Completion | None:
|
|
if isinstance(ref, ResourceTemplateReference):
|
|
if ref.uri == "db://{database}/{table}":
|
|
if argument.name == "table":
|
|
# Check if database context is provided
|
|
if not context or not context.arguments or "database" not in context.arguments:
|
|
# Raise an error instead of returning error as completion
|
|
raise ValueError("Please select a database first to see available tables")
|
|
# Normal completion if context is provided
|
|
db = context.arguments.get("database")
|
|
if db == "test_db":
|
|
return Completion(values=["users", "orders", "products"], total=3, hasMore=False)
|
|
|
|
return Completion(values=[], total=0, hasMore=False)
|
|
|
|
async with create_connected_server_and_client_session(server) as client:
|
|
# Try to complete table without database context - should raise error
|
|
with pytest.raises(Exception) as exc_info:
|
|
await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
|
argument={"name": "table", "value": ""},
|
|
)
|
|
|
|
# Verify error message
|
|
assert "Please select a database first" in str(exc_info.value)
|
|
|
|
# Now complete with proper context - should work normally
|
|
result_with_context = await client.complete(
|
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
|
argument={"name": "table", "value": ""},
|
|
context_arguments={"database": "test_db"},
|
|
)
|
|
|
|
# Should get normal completions
|
|
assert result_with_context.completion.values == ["users", "orders", "products"]
|