Include context into completions (#966)

This commit is contained in:
Inna Harper
2025-06-17 09:33:25 +01:00
committed by GitHub
parent 7b420656de
commit a3bcabdce2
7 changed files with 363 additions and 7 deletions

View File

@@ -18,7 +18,6 @@ from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
@@ -26,14 +25,24 @@ from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource
from mcp.shared.context import RequestContext
from mcp.types import (
Completion,
CompletionArgument,
CompletionContext,
CreateMessageRequestParams,
CreateMessageResult,
GetPromptResult,
InitializeResult,
LoggingMessageNotification,
ProgressNotification,
PromptReference,
ReadResourceResult,
ResourceListChangedNotification,
ResourceTemplateReference,
SamplingMessage,
ServerNotification,
TextContent,
TextResourceContents,
ToolListChangedNotification,
)
@@ -191,6 +200,40 @@ def make_everything_fastmcp() -> FastMCP:
# Since FastMCP doesn't support system messages in the same way
return f"Context: {context}. Query: {user_query}"
# Resource template with completion support
@mcp.resource("github://repos/{owner}/{repo}")
def github_repo_resource(owner: str, repo: str) -> str:
return f"Repository: {owner}/{repo}"
# Add completion handler for the server
@mcp.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
# Handle GitHub repository completion
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol":
# Return repos for modelcontextprotocol org
return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False)
elif context and context.arguments and context.arguments.get("owner") == "test-org":
# Return repos for test-org
return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False)
# Handle prompt completions
if isinstance(ref, PromptReference):
if ref.name == "complex_prompt" and argument.name == "context":
# Complete context values
contexts = ["general", "technical", "business", "academic"]
return Completion(
values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False
)
# Default: no completion available
return Completion(values=[], total=0, hasMore=False)
# Tool that echoes request headers from context
@mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
@@ -597,15 +640,15 @@ class NotificationCollector:
async def handle_generic_notification(self, message) -> None:
# Check if this is a ServerNotification
if isinstance(message, types.ServerNotification):
if isinstance(message, ServerNotification):
# Check the specific notification type
if isinstance(message.root, types.ProgressNotification):
if isinstance(message.root, ProgressNotification):
await self.handle_progress(message.root.params)
elif isinstance(message.root, types.LoggingMessageNotification):
elif isinstance(message.root, LoggingMessageNotification):
await self.handle_log(message.root.params)
elif isinstance(message.root, types.ResourceListChangedNotification):
elif isinstance(message.root, ResourceListChangedNotification):
await self.handle_resource_list_changed(message.root.params)
elif isinstance(message.root, types.ToolListChangedNotification):
elif isinstance(message.root, ToolListChangedNotification):
await self.handle_tool_list_changed(message.root.params)
@@ -781,6 +824,41 @@ async def call_all_mcp_features(session: ClientSession, collector: NotificationC
if context_data["method"]:
assert context_data["method"] == "POST"
# Test completion functionality
# 1. Test resource template completion with context
repo_result = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": ""},
context_arguments={"owner": "modelcontextprotocol"},
)
assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"]
assert repo_result.completion.total == 3
assert repo_result.completion.hasMore is False
# 2. Test with different context
repo_result2 = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": ""},
context_arguments={"owner": "test-org"},
)
assert repo_result2.completion.values == ["test-repo1", "test-repo2"]
assert repo_result2.completion.total == 2
# 3. Test prompt argument completion
context_result = await session.complete(
ref=PromptReference(type="ref/prompt", name="complex_prompt"),
argument={"name": "context", "value": "tech"},
)
assert "technical" in context_result.completion.values
# 4. Test completion without context (should return empty)
no_context_result = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": "test"},
)
assert no_context_result.completion.values == []
assert no_context_result.completion.total == 0
async def sampling_callback(
context: RequestContext[ClientSession, None],

View File

@@ -0,0 +1,180 @@
"""
Tests for completion handler with context functionality.
"""
import pytest
from mcp.server.lowlevel import Server
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
Completion,
CompletionArgument,
CompletionContext,
PromptReference,
ResourceTemplateReference,
)
@pytest.mark.anyio
async def test_completion_handler_receives_context():
"""Test that the completion handler receives context correctly."""
server = Server("test-server")
# Track what the handler receives
received_args = {}
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
received_args["ref"] = ref
received_args["argument"] = argument
received_args["context"] = context
# Return test completion
return Completion(values=["test-completion"], total=1, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Test with context
result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"),
argument={"name": "param", "value": "test"},
context_arguments={"previous": "value"},
)
# Verify handler received the context
assert received_args["context"] is not None
assert received_args["context"].arguments == {"previous": "value"}
assert result.completion.values == ["test-completion"]
@pytest.mark.anyio
async def test_completion_backward_compatibility():
"""Test that completion works without context (backward compatibility)."""
server = Server("test-server")
context_was_none = False
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
nonlocal context_was_none
context_was_none = context is None
return Completion(values=["no-context-completion"], total=1, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Test without context
result = await client.complete(
ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"}
)
# Verify context was None
assert context_was_none
assert result.completion.values == ["no-context-completion"]
@pytest.mark.anyio
async def test_dependent_completion_scenario():
"""Test a real-world scenario with dependent completions."""
server = Server("test-server")
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
# Simulate database/table completion scenario
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "db://{database}/{table}":
if argument.name == "database":
# Complete database names
return Completion(values=["users_db", "products_db", "analytics_db"], total=3, hasMore=False)
elif argument.name == "table":
# Complete table names based on selected database
if context and context.arguments:
db = context.arguments.get("database")
if db == "users_db":
return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False)
elif db == "products_db":
return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False)
return Completion(values=[], total=0, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# First, complete database
db_result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "database", "value": ""},
)
assert "users_db" in db_result.completion.values
assert "products_db" in db_result.completion.values
# Then complete table with database context
table_result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "users_db"},
)
assert table_result.completion.values == ["users", "sessions", "permissions"]
# Different database gives different tables
table_result2 = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "products_db"},
)
assert table_result2.completion.values == ["products", "categories", "inventory"]
@pytest.mark.anyio
async def test_completion_error_on_missing_context():
"""Test that server can raise error when required context is missing."""
server = Server("test-server")
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "db://{database}/{table}":
if argument.name == "table":
# Check if database context is provided
if not context or not context.arguments or "database" not in context.arguments:
# Raise an error instead of returning error as completion
raise ValueError("Please select a database first to see available tables")
# Normal completion if context is provided
db = context.arguments.get("database")
if db == "test_db":
return Completion(values=["users", "orders", "products"], total=3, hasMore=False)
return Completion(values=[], total=0, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Try to complete table without database context - should raise error
with pytest.raises(Exception) as exc_info:
await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
)
# Verify error message
assert "Please select a database first" in str(exc_info.value)
# Now complete with proper context - should work normally
result_with_context = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "test_db"},
)
# Should get normal completions
assert result_with_context.completion.values == ["users", "orders", "products"]