Include context into completions (#966)

This commit is contained in:
Inna Harper
2025-06-17 09:33:25 +01:00
committed by GitHub
parent 7b420656de
commit a3bcabdce2
7 changed files with 363 additions and 7 deletions

View File

@@ -30,6 +30,7 @@
- [Prompts](#prompts) - [Prompts](#prompts)
- [Images](#images) - [Images](#images)
- [Context](#context) - [Context](#context)
- [Completions](#completions)
- [Running Your Server](#running-your-server) - [Running Your Server](#running-your-server)
- [Development Mode](#development-mode) - [Development Mode](#development-mode)
- [Claude Desktop Integration](#claude-desktop-integration) - [Claude Desktop Integration](#claude-desktop-integration)
@@ -310,6 +311,68 @@ async def long_task(files: list[str], ctx: Context) -> str:
return "Processing complete" return "Processing complete"
``` ```
### Completions
MCP supports providing completion suggestions for prompt arguments and resource template parameters. With the context parameter, servers can provide completions based on previously resolved values:
Client usage:
```python
from mcp.client.session import ClientSession
from mcp.types import ResourceTemplateReference
async def use_completion(session: ClientSession):
# Complete without context
result = await session.complete(
ref=ResourceTemplateReference(
type="ref/resource", uri="github://repos/{owner}/{repo}"
),
argument={"name": "owner", "value": "model"},
)
# Complete with context - repo suggestions based on owner
result = await session.complete(
ref=ResourceTemplateReference(
type="ref/resource", uri="github://repos/{owner}/{repo}"
),
argument={"name": "repo", "value": "test"},
context_arguments={"owner": "modelcontextprotocol"},
)
```
Server implementation:
```python
from mcp.server import Server
from mcp.types import (
Completion,
CompletionArgument,
CompletionContext,
PromptReference,
ResourceTemplateReference,
)
server = Server("example-server")
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
# Use context to provide owner-specific repos
if context and context.arguments:
owner = context.arguments.get("owner")
if owner == "modelcontextprotocol":
repos = ["python-sdk", "typescript-sdk", "specification"]
# Filter based on partial input
filtered = [r for r in repos if r.startswith(argument.value)]
return Completion(values=filtered)
return None
```
### Authentication ### Authentication
Authentication can be used by servers that want to expose tools accessing protected resources. Authentication can be used by servers that want to expose tools accessing protected resources.

View File

@@ -304,8 +304,13 @@ class ClientSession(
self, self,
ref: types.ResourceTemplateReference | types.PromptReference, ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str], argument: dict[str, str],
context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult: ) -> types.CompleteResult:
"""Send a completion/complete request.""" """Send a completion/complete request."""
context = None
if context_arguments is not None:
context = types.CompletionContext(arguments=context_arguments)
return await self.send_request( return await self.send_request(
types.ClientRequest( types.ClientRequest(
types.CompleteRequest( types.CompleteRequest(
@@ -313,6 +318,7 @@ class ClientSession(
params=types.CompleteRequestParams( params=types.CompleteRequestParams(
ref=ref, ref=ref,
argument=types.CompletionArgument(**argument), argument=types.CompletionArgument(**argument),
context=context,
), ),
) )
), ),

View File

@@ -364,6 +364,24 @@ class FastMCP:
return decorator return decorator
def completion(self):
"""Decorator to register a completion handler.
The completion handler receives:
- ref: PromptReference or ResourceTemplateReference
- argument: CompletionArgument with name and partial value
- context: Optional CompletionContext with previously resolved arguments
Example:
@mcp.completion()
async def handle_completion(ref, argument, context):
if isinstance(ref, ResourceTemplateReference):
# Return completions based on ref, argument, and context
return Completion(values=["option1", "option2"])
return None
"""
return self._mcp_server.completion()
def add_resource(self, resource: Resource) -> None: def add_resource(self, resource: Resource) -> None:
"""Add a resource to the server. """Add a resource to the server.

View File

@@ -433,6 +433,7 @@ class Server(Generic[LifespanResultT, RequestT]):
[ [
types.PromptReference | types.ResourceTemplateReference, types.PromptReference | types.ResourceTemplateReference,
types.CompletionArgument, types.CompletionArgument,
types.CompletionContext | None,
], ],
Awaitable[types.Completion | None], Awaitable[types.Completion | None],
], ],
@@ -440,7 +441,7 @@ class Server(Generic[LifespanResultT, RequestT]):
logger.debug("Registering handler for CompleteRequest") logger.debug("Registering handler for CompleteRequest")
async def handler(req: types.CompleteRequest): async def handler(req: types.CompleteRequest):
completion = await func(req.params.ref, req.params.argument) completion = await func(req.params.ref, req.params.argument, req.params.context)
return types.ServerResult( return types.ServerResult(
types.CompleteResult( types.CompleteResult(
completion=completion completion=completion

View File

@@ -1028,11 +1028,21 @@ class CompletionArgument(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class CompletionContext(BaseModel):
"""Additional, optional context for completions."""
arguments: dict[str, str] | None = None
"""Previously-resolved variables in a URI template or prompt."""
model_config = ConfigDict(extra="allow")
class CompleteRequestParams(RequestParams): class CompleteRequestParams(RequestParams):
"""Parameters for completion requests.""" """Parameters for completion requests."""
ref: ResourceTemplateReference | PromptReference ref: ResourceTemplateReference | PromptReference
argument: CompletionArgument argument: CompletionArgument
context: CompletionContext | None = None
"""Additional, optional context for completions"""
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -18,7 +18,6 @@ from pydantic import AnyUrl
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
import mcp.types as types
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
@@ -26,14 +25,24 @@ from mcp.server.fastmcp import FastMCP
from mcp.server.fastmcp.resources import FunctionResource from mcp.server.fastmcp.resources import FunctionResource
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
from mcp.types import ( from mcp.types import (
Completion,
CompletionArgument,
CompletionContext,
CreateMessageRequestParams, CreateMessageRequestParams,
CreateMessageResult, CreateMessageResult,
GetPromptResult, GetPromptResult,
InitializeResult, InitializeResult,
LoggingMessageNotification,
ProgressNotification,
PromptReference,
ReadResourceResult, ReadResourceResult,
ResourceListChangedNotification,
ResourceTemplateReference,
SamplingMessage, SamplingMessage,
ServerNotification,
TextContent, TextContent,
TextResourceContents, TextResourceContents,
ToolListChangedNotification,
) )
@@ -191,6 +200,40 @@ def make_everything_fastmcp() -> FastMCP:
# Since FastMCP doesn't support system messages in the same way # Since FastMCP doesn't support system messages in the same way
return f"Context: {context}. Query: {user_query}" return f"Context: {context}. Query: {user_query}"
# Resource template with completion support
@mcp.resource("github://repos/{owner}/{repo}")
def github_repo_resource(owner: str, repo: str) -> str:
return f"Repository: {owner}/{repo}"
# Add completion handler for the server
@mcp.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
# Handle GitHub repository completion
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol":
# Return repos for modelcontextprotocol org
return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False)
elif context and context.arguments and context.arguments.get("owner") == "test-org":
# Return repos for test-org
return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False)
# Handle prompt completions
if isinstance(ref, PromptReference):
if ref.name == "complex_prompt" and argument.name == "context":
# Complete context values
contexts = ["general", "technical", "business", "academic"]
return Completion(
values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False
)
# Default: no completion available
return Completion(values=[], total=0, hasMore=False)
# Tool that echoes request headers from context # Tool that echoes request headers from context
@mcp.tool(description="Echo request headers from context") @mcp.tool(description="Echo request headers from context")
def echo_headers(ctx: Context[Any, Any, Request]) -> str: def echo_headers(ctx: Context[Any, Any, Request]) -> str:
@@ -597,15 +640,15 @@ class NotificationCollector:
async def handle_generic_notification(self, message) -> None: async def handle_generic_notification(self, message) -> None:
# Check if this is a ServerNotification # Check if this is a ServerNotification
if isinstance(message, types.ServerNotification): if isinstance(message, ServerNotification):
# Check the specific notification type # Check the specific notification type
if isinstance(message.root, types.ProgressNotification): if isinstance(message.root, ProgressNotification):
await self.handle_progress(message.root.params) await self.handle_progress(message.root.params)
elif isinstance(message.root, types.LoggingMessageNotification): elif isinstance(message.root, LoggingMessageNotification):
await self.handle_log(message.root.params) await self.handle_log(message.root.params)
elif isinstance(message.root, types.ResourceListChangedNotification): elif isinstance(message.root, ResourceListChangedNotification):
await self.handle_resource_list_changed(message.root.params) await self.handle_resource_list_changed(message.root.params)
elif isinstance(message.root, types.ToolListChangedNotification): elif isinstance(message.root, ToolListChangedNotification):
await self.handle_tool_list_changed(message.root.params) await self.handle_tool_list_changed(message.root.params)
@@ -781,6 +824,41 @@ async def call_all_mcp_features(session: ClientSession, collector: NotificationC
if context_data["method"]: if context_data["method"]:
assert context_data["method"] == "POST" assert context_data["method"] == "POST"
# Test completion functionality
# 1. Test resource template completion with context
repo_result = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": ""},
context_arguments={"owner": "modelcontextprotocol"},
)
assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"]
assert repo_result.completion.total == 3
assert repo_result.completion.hasMore is False
# 2. Test with different context
repo_result2 = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": ""},
context_arguments={"owner": "test-org"},
)
assert repo_result2.completion.values == ["test-repo1", "test-repo2"]
assert repo_result2.completion.total == 2
# 3. Test prompt argument completion
context_result = await session.complete(
ref=PromptReference(type="ref/prompt", name="complex_prompt"),
argument={"name": "context", "value": "tech"},
)
assert "technical" in context_result.completion.values
# 4. Test completion without context (should return empty)
no_context_result = await session.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
argument={"name": "repo", "value": "test"},
)
assert no_context_result.completion.values == []
assert no_context_result.completion.total == 0
async def sampling_callback( async def sampling_callback(
context: RequestContext[ClientSession, None], context: RequestContext[ClientSession, None],

View File

@@ -0,0 +1,180 @@
"""
Tests for completion handler with context functionality.
"""
import pytest
from mcp.server.lowlevel import Server
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
Completion,
CompletionArgument,
CompletionContext,
PromptReference,
ResourceTemplateReference,
)
@pytest.mark.anyio
async def test_completion_handler_receives_context():
"""Test that the completion handler receives context correctly."""
server = Server("test-server")
# Track what the handler receives
received_args = {}
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
received_args["ref"] = ref
received_args["argument"] = argument
received_args["context"] = context
# Return test completion
return Completion(values=["test-completion"], total=1, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Test with context
result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"),
argument={"name": "param", "value": "test"},
context_arguments={"previous": "value"},
)
# Verify handler received the context
assert received_args["context"] is not None
assert received_args["context"].arguments == {"previous": "value"}
assert result.completion.values == ["test-completion"]
@pytest.mark.anyio
async def test_completion_backward_compatibility():
"""Test that completion works without context (backward compatibility)."""
server = Server("test-server")
context_was_none = False
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
nonlocal context_was_none
context_was_none = context is None
return Completion(values=["no-context-completion"], total=1, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Test without context
result = await client.complete(
ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"}
)
# Verify context was None
assert context_was_none
assert result.completion.values == ["no-context-completion"]
@pytest.mark.anyio
async def test_dependent_completion_scenario():
"""Test a real-world scenario with dependent completions."""
server = Server("test-server")
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
# Simulate database/table completion scenario
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "db://{database}/{table}":
if argument.name == "database":
# Complete database names
return Completion(values=["users_db", "products_db", "analytics_db"], total=3, hasMore=False)
elif argument.name == "table":
# Complete table names based on selected database
if context and context.arguments:
db = context.arguments.get("database")
if db == "users_db":
return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False)
elif db == "products_db":
return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False)
return Completion(values=[], total=0, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# First, complete database
db_result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "database", "value": ""},
)
assert "users_db" in db_result.completion.values
assert "products_db" in db_result.completion.values
# Then complete table with database context
table_result = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "users_db"},
)
assert table_result.completion.values == ["users", "sessions", "permissions"]
# Different database gives different tables
table_result2 = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "products_db"},
)
assert table_result2.completion.values == ["products", "categories", "inventory"]
@pytest.mark.anyio
async def test_completion_error_on_missing_context():
"""Test that server can raise error when required context is missing."""
server = Server("test-server")
@server.completion()
async def handle_completion(
ref: PromptReference | ResourceTemplateReference,
argument: CompletionArgument,
context: CompletionContext | None,
) -> Completion | None:
if isinstance(ref, ResourceTemplateReference):
if ref.uri == "db://{database}/{table}":
if argument.name == "table":
# Check if database context is provided
if not context or not context.arguments or "database" not in context.arguments:
# Raise an error instead of returning error as completion
raise ValueError("Please select a database first to see available tables")
# Normal completion if context is provided
db = context.arguments.get("database")
if db == "test_db":
return Completion(values=["users", "orders", "products"], total=3, hasMore=False)
return Completion(values=[], total=0, hasMore=False)
async with create_connected_server_and_client_session(server) as client:
# Try to complete table without database context - should raise error
with pytest.raises(Exception) as exc_info:
await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
)
# Verify error message
assert "Please select a database first" in str(exc_info.value)
# Now complete with proper context - should work normally
result_with_context = await client.complete(
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
argument={"name": "table", "value": ""},
context_arguments={"database": "test_db"},
)
# Should get normal completions
assert result_with_context.completion.values == ["users", "orders", "products"]