Update convenience methods on ClientSession and ServerSession

This commit is contained in:
Justin Spahr-Summers
2024-11-06 12:24:53 +00:00
parent 1634343931
commit 4ac03d40f9
2 changed files with 133 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic import AnyUrl, FileUrl
from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -12,14 +12,21 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListRootsResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
ServerNotification,
ServerRequest,
)
@@ -61,7 +68,12 @@ class ClientSession(
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
sampling=None, experimental=None
sampling=None,
experimental=None,
roots={
# TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported?
"listChanged": True
}
),
clientInfo=Implementation(name="mcp_python", version="0.1.0"),
),
@@ -220,3 +232,72 @@ class ClientSession(
),
CallToolResult,
)
async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp_python.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)
async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument
return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp_python.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp_python.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)

View File

@@ -12,7 +12,7 @@ from mcp_python.shared.session import (
RequestResponder,
)
from mcp_python.types import (
LATEST_PROTOCOL_VERSION,
ListRootsResult, LATEST_PROTOCOL_VERSION,
ClientNotification,
ClientRequest,
CreateMessageResult,
@@ -28,6 +28,10 @@ from mcp_python.types import (
ServerNotification,
ServerRequest,
ServerResult,
ResourceListChangedNotification,
ToolListChangedNotification,
PromptListChangedNotification,
ModelPreferences,
)
@@ -142,6 +146,7 @@ class ServerSession(
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult:
"""Send a sampling/create_message request."""
from mcp_python.types import (
@@ -161,12 +166,26 @@ class ServerSession(
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
),
)
),
CreateMessageResult,
)
async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp_python.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)
async def send_ping(self) -> EmptyResult:
"""Send a ping request."""
from mcp_python.types import PingRequest
@@ -198,3 +217,33 @@ class ServerSession(
)
)
)
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)