mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
Include context into completions (#966)
This commit is contained in:
63
README.md
63
README.md
@@ -30,6 +30,7 @@
|
|||||||
- [Prompts](#prompts)
|
- [Prompts](#prompts)
|
||||||
- [Images](#images)
|
- [Images](#images)
|
||||||
- [Context](#context)
|
- [Context](#context)
|
||||||
|
- [Completions](#completions)
|
||||||
- [Running Your Server](#running-your-server)
|
- [Running Your Server](#running-your-server)
|
||||||
- [Development Mode](#development-mode)
|
- [Development Mode](#development-mode)
|
||||||
- [Claude Desktop Integration](#claude-desktop-integration)
|
- [Claude Desktop Integration](#claude-desktop-integration)
|
||||||
@@ -310,6 +311,68 @@ async def long_task(files: list[str], ctx: Context) -> str:
|
|||||||
return "Processing complete"
|
return "Processing complete"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Completions
|
||||||
|
|
||||||
|
MCP supports providing completion suggestions for prompt arguments and resource template parameters. With the context parameter, servers can provide completions based on previously resolved values:
|
||||||
|
|
||||||
|
Client usage:
|
||||||
|
```python
|
||||||
|
from mcp.client.session import ClientSession
|
||||||
|
from mcp.types import ResourceTemplateReference
|
||||||
|
|
||||||
|
|
||||||
|
async def use_completion(session: ClientSession):
|
||||||
|
# Complete without context
|
||||||
|
result = await session.complete(
|
||||||
|
ref=ResourceTemplateReference(
|
||||||
|
type="ref/resource", uri="github://repos/{owner}/{repo}"
|
||||||
|
),
|
||||||
|
argument={"name": "owner", "value": "model"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Complete with context - repo suggestions based on owner
|
||||||
|
result = await session.complete(
|
||||||
|
ref=ResourceTemplateReference(
|
||||||
|
type="ref/resource", uri="github://repos/{owner}/{repo}"
|
||||||
|
),
|
||||||
|
argument={"name": "repo", "value": "test"},
|
||||||
|
context_arguments={"owner": "modelcontextprotocol"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Server implementation:
|
||||||
|
```python
|
||||||
|
from mcp.server import Server
|
||||||
|
from mcp.types import (
|
||||||
|
Completion,
|
||||||
|
CompletionArgument,
|
||||||
|
CompletionContext,
|
||||||
|
PromptReference,
|
||||||
|
ResourceTemplateReference,
|
||||||
|
)
|
||||||
|
|
||||||
|
server = Server("example-server")
|
||||||
|
|
||||||
|
|
||||||
|
@server.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
if isinstance(ref, ResourceTemplateReference):
|
||||||
|
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
|
||||||
|
# Use context to provide owner-specific repos
|
||||||
|
if context and context.arguments:
|
||||||
|
owner = context.arguments.get("owner")
|
||||||
|
if owner == "modelcontextprotocol":
|
||||||
|
repos = ["python-sdk", "typescript-sdk", "specification"]
|
||||||
|
# Filter based on partial input
|
||||||
|
filtered = [r for r in repos if r.startswith(argument.value)]
|
||||||
|
return Completion(values=filtered)
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
|
||||||
### Authentication
|
### Authentication
|
||||||
|
|
||||||
Authentication can be used by servers that want to expose tools accessing protected resources.
|
Authentication can be used by servers that want to expose tools accessing protected resources.
|
||||||
|
|||||||
@@ -304,8 +304,13 @@ class ClientSession(
|
|||||||
self,
|
self,
|
||||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||||
argument: dict[str, str],
|
argument: dict[str, str],
|
||||||
|
context_arguments: dict[str, str] | None = None,
|
||||||
) -> types.CompleteResult:
|
) -> types.CompleteResult:
|
||||||
"""Send a completion/complete request."""
|
"""Send a completion/complete request."""
|
||||||
|
context = None
|
||||||
|
if context_arguments is not None:
|
||||||
|
context = types.CompletionContext(arguments=context_arguments)
|
||||||
|
|
||||||
return await self.send_request(
|
return await self.send_request(
|
||||||
types.ClientRequest(
|
types.ClientRequest(
|
||||||
types.CompleteRequest(
|
types.CompleteRequest(
|
||||||
@@ -313,6 +318,7 @@ class ClientSession(
|
|||||||
params=types.CompleteRequestParams(
|
params=types.CompleteRequestParams(
|
||||||
ref=ref,
|
ref=ref,
|
||||||
argument=types.CompletionArgument(**argument),
|
argument=types.CompletionArgument(**argument),
|
||||||
|
context=context,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -364,6 +364,24 @@ class FastMCP:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def completion(self):
|
||||||
|
"""Decorator to register a completion handler.
|
||||||
|
|
||||||
|
The completion handler receives:
|
||||||
|
- ref: PromptReference or ResourceTemplateReference
|
||||||
|
- argument: CompletionArgument with name and partial value
|
||||||
|
- context: Optional CompletionContext with previously resolved arguments
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@mcp.completion()
|
||||||
|
async def handle_completion(ref, argument, context):
|
||||||
|
if isinstance(ref, ResourceTemplateReference):
|
||||||
|
# Return completions based on ref, argument, and context
|
||||||
|
return Completion(values=["option1", "option2"])
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
return self._mcp_server.completion()
|
||||||
|
|
||||||
def add_resource(self, resource: Resource) -> None:
|
def add_resource(self, resource: Resource) -> None:
|
||||||
"""Add a resource to the server.
|
"""Add a resource to the server.
|
||||||
|
|
||||||
|
|||||||
@@ -433,6 +433,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
|||||||
[
|
[
|
||||||
types.PromptReference | types.ResourceTemplateReference,
|
types.PromptReference | types.ResourceTemplateReference,
|
||||||
types.CompletionArgument,
|
types.CompletionArgument,
|
||||||
|
types.CompletionContext | None,
|
||||||
],
|
],
|
||||||
Awaitable[types.Completion | None],
|
Awaitable[types.Completion | None],
|
||||||
],
|
],
|
||||||
@@ -440,7 +441,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
|||||||
logger.debug("Registering handler for CompleteRequest")
|
logger.debug("Registering handler for CompleteRequest")
|
||||||
|
|
||||||
async def handler(req: types.CompleteRequest):
|
async def handler(req: types.CompleteRequest):
|
||||||
completion = await func(req.params.ref, req.params.argument)
|
completion = await func(req.params.ref, req.params.argument, req.params.context)
|
||||||
return types.ServerResult(
|
return types.ServerResult(
|
||||||
types.CompleteResult(
|
types.CompleteResult(
|
||||||
completion=completion
|
completion=completion
|
||||||
|
|||||||
@@ -1028,11 +1028,21 @@ class CompletionArgument(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionContext(BaseModel):
|
||||||
|
"""Additional, optional context for completions."""
|
||||||
|
|
||||||
|
arguments: dict[str, str] | None = None
|
||||||
|
"""Previously-resolved variables in a URI template or prompt."""
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class CompleteRequestParams(RequestParams):
|
class CompleteRequestParams(RequestParams):
|
||||||
"""Parameters for completion requests."""
|
"""Parameters for completion requests."""
|
||||||
|
|
||||||
ref: ResourceTemplateReference | PromptReference
|
ref: ResourceTemplateReference | PromptReference
|
||||||
argument: CompletionArgument
|
argument: CompletionArgument
|
||||||
|
context: CompletionContext | None = None
|
||||||
|
"""Additional, optional context for completions"""
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from pydantic import AnyUrl
|
|||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
import mcp.types as types
|
|
||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
@@ -26,14 +25,24 @@ from mcp.server.fastmcp import FastMCP
|
|||||||
from mcp.server.fastmcp.resources import FunctionResource
|
from mcp.server.fastmcp.resources import FunctionResource
|
||||||
from mcp.shared.context import RequestContext
|
from mcp.shared.context import RequestContext
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
Completion,
|
||||||
|
CompletionArgument,
|
||||||
|
CompletionContext,
|
||||||
CreateMessageRequestParams,
|
CreateMessageRequestParams,
|
||||||
CreateMessageResult,
|
CreateMessageResult,
|
||||||
GetPromptResult,
|
GetPromptResult,
|
||||||
InitializeResult,
|
InitializeResult,
|
||||||
|
LoggingMessageNotification,
|
||||||
|
ProgressNotification,
|
||||||
|
PromptReference,
|
||||||
ReadResourceResult,
|
ReadResourceResult,
|
||||||
|
ResourceListChangedNotification,
|
||||||
|
ResourceTemplateReference,
|
||||||
SamplingMessage,
|
SamplingMessage,
|
||||||
|
ServerNotification,
|
||||||
TextContent,
|
TextContent,
|
||||||
TextResourceContents,
|
TextResourceContents,
|
||||||
|
ToolListChangedNotification,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -191,6 +200,40 @@ def make_everything_fastmcp() -> FastMCP:
|
|||||||
# Since FastMCP doesn't support system messages in the same way
|
# Since FastMCP doesn't support system messages in the same way
|
||||||
return f"Context: {context}. Query: {user_query}"
|
return f"Context: {context}. Query: {user_query}"
|
||||||
|
|
||||||
|
# Resource template with completion support
|
||||||
|
@mcp.resource("github://repos/{owner}/{repo}")
|
||||||
|
def github_repo_resource(owner: str, repo: str) -> str:
|
||||||
|
return f"Repository: {owner}/{repo}"
|
||||||
|
|
||||||
|
# Add completion handler for the server
|
||||||
|
@mcp.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
# Handle GitHub repository completion
|
||||||
|
if isinstance(ref, ResourceTemplateReference):
|
||||||
|
if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo":
|
||||||
|
if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol":
|
||||||
|
# Return repos for modelcontextprotocol org
|
||||||
|
return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False)
|
||||||
|
elif context and context.arguments and context.arguments.get("owner") == "test-org":
|
||||||
|
# Return repos for test-org
|
||||||
|
return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False)
|
||||||
|
|
||||||
|
# Handle prompt completions
|
||||||
|
if isinstance(ref, PromptReference):
|
||||||
|
if ref.name == "complex_prompt" and argument.name == "context":
|
||||||
|
# Complete context values
|
||||||
|
contexts = ["general", "technical", "business", "academic"]
|
||||||
|
return Completion(
|
||||||
|
values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default: no completion available
|
||||||
|
return Completion(values=[], total=0, hasMore=False)
|
||||||
|
|
||||||
# Tool that echoes request headers from context
|
# Tool that echoes request headers from context
|
||||||
@mcp.tool(description="Echo request headers from context")
|
@mcp.tool(description="Echo request headers from context")
|
||||||
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
|
def echo_headers(ctx: Context[Any, Any, Request]) -> str:
|
||||||
@@ -597,15 +640,15 @@ class NotificationCollector:
|
|||||||
|
|
||||||
async def handle_generic_notification(self, message) -> None:
|
async def handle_generic_notification(self, message) -> None:
|
||||||
# Check if this is a ServerNotification
|
# Check if this is a ServerNotification
|
||||||
if isinstance(message, types.ServerNotification):
|
if isinstance(message, ServerNotification):
|
||||||
# Check the specific notification type
|
# Check the specific notification type
|
||||||
if isinstance(message.root, types.ProgressNotification):
|
if isinstance(message.root, ProgressNotification):
|
||||||
await self.handle_progress(message.root.params)
|
await self.handle_progress(message.root.params)
|
||||||
elif isinstance(message.root, types.LoggingMessageNotification):
|
elif isinstance(message.root, LoggingMessageNotification):
|
||||||
await self.handle_log(message.root.params)
|
await self.handle_log(message.root.params)
|
||||||
elif isinstance(message.root, types.ResourceListChangedNotification):
|
elif isinstance(message.root, ResourceListChangedNotification):
|
||||||
await self.handle_resource_list_changed(message.root.params)
|
await self.handle_resource_list_changed(message.root.params)
|
||||||
elif isinstance(message.root, types.ToolListChangedNotification):
|
elif isinstance(message.root, ToolListChangedNotification):
|
||||||
await self.handle_tool_list_changed(message.root.params)
|
await self.handle_tool_list_changed(message.root.params)
|
||||||
|
|
||||||
|
|
||||||
@@ -781,6 +824,41 @@ async def call_all_mcp_features(session: ClientSession, collector: NotificationC
|
|||||||
if context_data["method"]:
|
if context_data["method"]:
|
||||||
assert context_data["method"] == "POST"
|
assert context_data["method"] == "POST"
|
||||||
|
|
||||||
|
# Test completion functionality
|
||||||
|
# 1. Test resource template completion with context
|
||||||
|
repo_result = await session.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||||
|
argument={"name": "repo", "value": ""},
|
||||||
|
context_arguments={"owner": "modelcontextprotocol"},
|
||||||
|
)
|
||||||
|
assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"]
|
||||||
|
assert repo_result.completion.total == 3
|
||||||
|
assert repo_result.completion.hasMore is False
|
||||||
|
|
||||||
|
# 2. Test with different context
|
||||||
|
repo_result2 = await session.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||||
|
argument={"name": "repo", "value": ""},
|
||||||
|
context_arguments={"owner": "test-org"},
|
||||||
|
)
|
||||||
|
assert repo_result2.completion.values == ["test-repo1", "test-repo2"]
|
||||||
|
assert repo_result2.completion.total == 2
|
||||||
|
|
||||||
|
# 3. Test prompt argument completion
|
||||||
|
context_result = await session.complete(
|
||||||
|
ref=PromptReference(type="ref/prompt", name="complex_prompt"),
|
||||||
|
argument={"name": "context", "value": "tech"},
|
||||||
|
)
|
||||||
|
assert "technical" in context_result.completion.values
|
||||||
|
|
||||||
|
# 4. Test completion without context (should return empty)
|
||||||
|
no_context_result = await session.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"),
|
||||||
|
argument={"name": "repo", "value": "test"},
|
||||||
|
)
|
||||||
|
assert no_context_result.completion.values == []
|
||||||
|
assert no_context_result.completion.total == 0
|
||||||
|
|
||||||
|
|
||||||
async def sampling_callback(
|
async def sampling_callback(
|
||||||
context: RequestContext[ClientSession, None],
|
context: RequestContext[ClientSession, None],
|
||||||
|
|||||||
180
tests/server/test_completion_with_context.py
Normal file
180
tests/server/test_completion_with_context.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""
|
||||||
|
Tests for completion handler with context functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mcp.server.lowlevel import Server
|
||||||
|
from mcp.shared.memory import create_connected_server_and_client_session
|
||||||
|
from mcp.types import (
|
||||||
|
Completion,
|
||||||
|
CompletionArgument,
|
||||||
|
CompletionContext,
|
||||||
|
PromptReference,
|
||||||
|
ResourceTemplateReference,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_completion_handler_receives_context():
|
||||||
|
"""Test that the completion handler receives context correctly."""
|
||||||
|
server = Server("test-server")
|
||||||
|
|
||||||
|
# Track what the handler receives
|
||||||
|
received_args = {}
|
||||||
|
|
||||||
|
@server.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
received_args["ref"] = ref
|
||||||
|
received_args["argument"] = argument
|
||||||
|
received_args["context"] = context
|
||||||
|
|
||||||
|
# Return test completion
|
||||||
|
return Completion(values=["test-completion"], total=1, hasMore=False)
|
||||||
|
|
||||||
|
async with create_connected_server_and_client_session(server) as client:
|
||||||
|
# Test with context
|
||||||
|
result = await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"),
|
||||||
|
argument={"name": "param", "value": "test"},
|
||||||
|
context_arguments={"previous": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify handler received the context
|
||||||
|
assert received_args["context"] is not None
|
||||||
|
assert received_args["context"].arguments == {"previous": "value"}
|
||||||
|
assert result.completion.values == ["test-completion"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_completion_backward_compatibility():
|
||||||
|
"""Test that completion works without context (backward compatibility)."""
|
||||||
|
server = Server("test-server")
|
||||||
|
|
||||||
|
context_was_none = False
|
||||||
|
|
||||||
|
@server.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
nonlocal context_was_none
|
||||||
|
context_was_none = context is None
|
||||||
|
|
||||||
|
return Completion(values=["no-context-completion"], total=1, hasMore=False)
|
||||||
|
|
||||||
|
async with create_connected_server_and_client_session(server) as client:
|
||||||
|
# Test without context
|
||||||
|
result = await client.complete(
|
||||||
|
ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify context was None
|
||||||
|
assert context_was_none
|
||||||
|
assert result.completion.values == ["no-context-completion"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependent_completion_scenario():
|
||||||
|
"""Test a real-world scenario with dependent completions."""
|
||||||
|
server = Server("test-server")
|
||||||
|
|
||||||
|
@server.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
# Simulate database/table completion scenario
|
||||||
|
if isinstance(ref, ResourceTemplateReference):
|
||||||
|
if ref.uri == "db://{database}/{table}":
|
||||||
|
if argument.name == "database":
|
||||||
|
# Complete database names
|
||||||
|
return Completion(values=["users_db", "products_db", "analytics_db"], total=3, hasMore=False)
|
||||||
|
elif argument.name == "table":
|
||||||
|
# Complete table names based on selected database
|
||||||
|
if context and context.arguments:
|
||||||
|
db = context.arguments.get("database")
|
||||||
|
if db == "users_db":
|
||||||
|
return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False)
|
||||||
|
elif db == "products_db":
|
||||||
|
return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False)
|
||||||
|
|
||||||
|
return Completion(values=[], total=0, hasMore=False)
|
||||||
|
|
||||||
|
async with create_connected_server_and_client_session(server) as client:
|
||||||
|
# First, complete database
|
||||||
|
db_result = await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
||||||
|
argument={"name": "database", "value": ""},
|
||||||
|
)
|
||||||
|
assert "users_db" in db_result.completion.values
|
||||||
|
assert "products_db" in db_result.completion.values
|
||||||
|
|
||||||
|
# Then complete table with database context
|
||||||
|
table_result = await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
||||||
|
argument={"name": "table", "value": ""},
|
||||||
|
context_arguments={"database": "users_db"},
|
||||||
|
)
|
||||||
|
assert table_result.completion.values == ["users", "sessions", "permissions"]
|
||||||
|
|
||||||
|
# Different database gives different tables
|
||||||
|
table_result2 = await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
||||||
|
argument={"name": "table", "value": ""},
|
||||||
|
context_arguments={"database": "products_db"},
|
||||||
|
)
|
||||||
|
assert table_result2.completion.values == ["products", "categories", "inventory"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_completion_error_on_missing_context():
|
||||||
|
"""Test that server can raise error when required context is missing."""
|
||||||
|
server = Server("test-server")
|
||||||
|
|
||||||
|
@server.completion()
|
||||||
|
async def handle_completion(
|
||||||
|
ref: PromptReference | ResourceTemplateReference,
|
||||||
|
argument: CompletionArgument,
|
||||||
|
context: CompletionContext | None,
|
||||||
|
) -> Completion | None:
|
||||||
|
if isinstance(ref, ResourceTemplateReference):
|
||||||
|
if ref.uri == "db://{database}/{table}":
|
||||||
|
if argument.name == "table":
|
||||||
|
# Check if database context is provided
|
||||||
|
if not context or not context.arguments or "database" not in context.arguments:
|
||||||
|
# Raise an error instead of returning error as completion
|
||||||
|
raise ValueError("Please select a database first to see available tables")
|
||||||
|
# Normal completion if context is provided
|
||||||
|
db = context.arguments.get("database")
|
||||||
|
if db == "test_db":
|
||||||
|
return Completion(values=["users", "orders", "products"], total=3, hasMore=False)
|
||||||
|
|
||||||
|
return Completion(values=[], total=0, hasMore=False)
|
||||||
|
|
||||||
|
async with create_connected_server_and_client_session(server) as client:
|
||||||
|
# Try to complete table without database context - should raise error
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
||||||
|
argument={"name": "table", "value": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify error message
|
||||||
|
assert "Please select a database first" in str(exc_info.value)
|
||||||
|
|
||||||
|
# Now complete with proper context - should work normally
|
||||||
|
result_with_context = await client.complete(
|
||||||
|
ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"),
|
||||||
|
argument={"name": "table", "value": ""},
|
||||||
|
context_arguments={"database": "test_db"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should get normal completions
|
||||||
|
assert result_with_context.completion.values == ["users", "orders", "products"]
|
||||||
Reference in New Issue
Block a user