mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2026-01-09 08:54:20 +01:00
Include context into completions (#966)
This commit is contained in:
@@ -18,7 +18,6 @@ from pydantic import AnyUrl
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
@@ -26,14 +25,24 @@ from mcp.server.fastmcp import FastMCP
|
||||
from mcp.server.fastmcp.resources import FunctionResource
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.types import (
|
||||
Completion,
|
||||
CompletionArgument,
|
||||
CompletionContext,
|
||||
CreateMessageRequestParams,
|
||||
CreateMessageResult,
|
||||
GetPromptResult,
|
||||
InitializeResult,
|
||||
LoggingMessageNotification,
|
||||
ProgressNotification,
|
||||
PromptReference,
|
||||
ReadResourceResult,
|
||||
ResourceListChangedNotification,
|
||||
ResourceTemplateReference,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
ToolListChangedNotification,
|
||||
)
|
||||
|
||||
|
||||
@@ -191,6 +200,40 @@ def make_everything_fastmcp() -> FastMCP:
|
||||
# Since FastMCP doesn't support system messages in the same way
|
||||
return f"Context: {context}. Query: {user_query}"
|
||||
|
||||
# Resource template with completion support
|
||||
@mcp.resource("github://repos/{owner}/{repo}")
|
||||
def github_repo_resource(owner: str, repo: str) -> str:
|
||||
return f"Repository: {owner}/{repo}"
|
||||
|
||||
# Add completion handler for the server
|
||||
@mcp.completion()
|
||||
async def handle_completion(
|
||||
ref: PromptReference | ResourceTemplateReference,
|
||||
argument: CompletionArgument,
|
||||
context: CompletionContext | None,
|
||||
) -> Completion | None:
|
||||
# Handle GitHub repository completion
|
||||
if isinstance(ref, ResourceTemplateReference):
|
||||
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
|
||||
if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol":
|
||||
# Return repos for modelcontextprotocol org
|
||||
return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False)
|
||||
elif context and context.arguments and context.arguments.get("owner") == "test-org":
|
||||
# Return repos for test-org
|
||||
return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False)
|
||||
|
||||
# Handle prompt completions
|
||||
if isinstance(ref, PromptReference):
|
||||
if ref.name == "complex_prompt" and argument.name == "context":
|
||||
# Complete context values
|
||||
contexts = ["general", "technical", "business", "academic"]
|
||||
return Completion(
|
||||
values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False
|
||||
)
|
||||
|
||||
# Default: no completion available
|
||||
return Completion(values=[], total=0, hasMore=False)
|
||||
|
||||
# Tool that echoes request headers from context
|
||||
@mcp.tool(description="Echo request headers from context")
|
||||
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
|
||||
@@ -597,15 +640,15 @@ class NotificationCollector:
|
||||
|
||||
async def handle_generic_notification(self, message) -> None:
|
||||
# Check if this is a ServerNotification
|
||||
if isinstance(message, types.ServerNotification):
|
||||
if isinstance(message, ServerNotification):
|
||||
# Check the specific notification type
|
||||
if isinstance(message.root, types.ProgressNotification):
|
||||
if isinstance(message.root, ProgressNotification):
|
||||
await self.handle_progress(message.root.params)
|
||||
elif isinstance(message.root, types.LoggingMessageNotification):
|
||||
elif isinstance(message.root, LoggingMessageNotification):
|
||||
await self.handle_log(message.root.params)
|
||||
elif isinstance(message.root, types.ResourceListChangedNotification):
|
||||
elif isinstance(message.root, ResourceListChangedNotification):
|
||||
await self.handle_resource_list_changed(message.root.params)
|
||||
elif isinstance(message.root, types.ToolListChangedNotification):
|
||||
elif isinstance(message.root, ToolListChangedNotification):
|
||||
await self.handle_tool_list_changed(message.root.params)
|
||||
|
||||
|
||||
@@ -781,6 +824,41 @@ async def call_all_mcp_features(session: ClientSession, collector: NotificationC
|
||||
if context_data["method"]:
|
||||
assert context_data["method"] == "POST"
|
||||
|
||||
# Test completion functionality
|
||||
# 1. Test resource template completion with context
|
||||
repo_result = await session.complete(
|
||||
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||
argument={"name": "repo", "value": ""},
|
||||
context_arguments={"owner": "modelcontextprotocol"},
|
||||
)
|
||||
assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"]
|
||||
assert repo_result.completion.total == 3
|
||||
assert repo_result.completion.hasMore is False
|
||||
|
||||
# 2. Test with different context
|
||||
repo_result2 = await session.complete(
|
||||
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||
argument={"name": "repo", "value": ""},
|
||||
context_arguments={"owner": "test-org"},
|
||||
)
|
||||
assert repo_result2.completion.values == ["test-repo1", "test-repo2"]
|
||||
assert repo_result2.completion.total == 2
|
||||
|
||||
# 3. Test prompt argument completion
|
||||
context_result = await session.complete(
|
||||
ref=PromptReference(type="ref/prompt", name="complex_prompt"),
|
||||
argument={"name": "context", "value": "tech"},
|
||||
)
|
||||
assert "technical" in context_result.completion.values
|
||||
|
||||
# 4. Test completion without context (should return empty)
|
||||
no_context_result = await session.complete(
|
||||
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||
argument={"name": "repo", "value": "test"},
|
||||
)
|
||||
assert no_context_result.completion.values == []
|
||||
assert no_context_result.completion.total == 0
|
||||
|
||||
|
||||
async def sampling_callback(
|
||||
context: RequestContext[ClientSession, None],
|
||||
|
||||
Reference in New Issue
Block a user