Update convenience methods on ClientSession and ServerSession

This commit is contained in:
Justin Spahr-Summers
2024-11-06 12:24:53 +00:00
parent 1634343931
commit 4ac03d40f9
2 changed files with 133 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
from datetime import timedelta from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl, FileUrl
from mcp_python.shared.session import BaseSession from mcp_python.shared.session import BaseSession
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -12,14 +12,21 @@ from mcp_python.types import (
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
ClientResult, ClientResult,
CompleteResult,
EmptyResult, EmptyResult,
GetPromptResult,
Implementation, Implementation,
InitializedNotification, InitializedNotification,
InitializeResult, InitializeResult,
JSONRPCMessage, JSONRPCMessage,
ListPromptsResult,
ListResourcesResult, ListResourcesResult,
ListRootsResult,
ListToolsResult,
LoggingLevel, LoggingLevel,
PromptReference,
ReadResourceResult, ReadResourceResult,
ResourceReference,
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
) )
@@ -61,7 +68,12 @@ class ClientSession(
params=InitializeRequestParams( params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION, protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities( capabilities=ClientCapabilities(
sampling=None, experimental=None sampling=None,
experimental=None,
roots={
# TODO: Should this be based on whether we _will_ send notifications, or only whether they're supported?
"listChanged": True
}
), ),
clientInfo=Implementation(name="mcp_python", version="0.1.0"), clientInfo=Implementation(name="mcp_python", version="0.1.0"),
), ),
@@ -220,3 +232,72 @@ class ClientSession(
), ),
CallToolResult, CallToolResult,
) )
async def list_prompts(self) -> ListPromptsResult:
"""Send a prompts/list request."""
from mcp_python.types import ListPromptsRequest
return await self.send_request(
ClientRequest(
ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> GetPromptResult:
"""Send a prompts/get request."""
from mcp_python.types import GetPromptRequest, GetPromptRequestParams
return await self.send_request(
ClientRequest(
GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
)
async def complete(self, ref: ResourceReference | PromptReference, argument: dict) -> CompleteResult:
"""Send a completion/complete request."""
from mcp_python.types import CompleteRequest, CompleteRequestParams, CompletionArgument
return await self.send_request(
ClientRequest(
CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
),
)
),
CompleteResult,
)
async def list_tools(self) -> ListToolsResult:
"""Send a tools/list request."""
from mcp_python.types import ListToolsRequest
return await self.send_request(
ClientRequest(
ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
)
async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp_python.types import RootsListChangedNotification
await self.send_notification(
ClientNotification(
RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)

View File

@@ -12,7 +12,7 @@ from mcp_python.shared.session import (
RequestResponder, RequestResponder,
) )
from mcp_python.types import ( from mcp_python.types import (
LATEST_PROTOCOL_VERSION, ListRootsResult, LATEST_PROTOCOL_VERSION,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CreateMessageResult, CreateMessageResult,
@@ -28,6 +28,10 @@ from mcp_python.types import (
ServerNotification, ServerNotification,
ServerRequest, ServerRequest,
ServerResult, ServerResult,
ResourceListChangedNotification,
ToolListChangedNotification,
PromptListChangedNotification,
ModelPreferences,
) )
@@ -142,6 +146,7 @@ class ServerSession(
temperature: float | None = None, temperature: float | None = None,
stop_sequences: list[str] | None = None, stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
) -> CreateMessageResult: ) -> CreateMessageResult:
"""Send a sampling/create_message request.""" """Send a sampling/create_message request."""
from mcp_python.types import ( from mcp_python.types import (
@@ -161,12 +166,26 @@ class ServerSession(
maxTokens=max_tokens, maxTokens=max_tokens,
stopSequences=stop_sequences, stopSequences=stop_sequences,
metadata=metadata, metadata=metadata,
modelPreferences=model_preferences,
), ),
) )
), ),
CreateMessageResult, CreateMessageResult,
) )
async def list_roots(self) -> ListRootsResult:
"""Send a roots/list request."""
from mcp_python.types import ListRootsRequest
return await self.send_request(
ServerRequest(
ListRootsRequest(
method="roots/list",
)
),
ListRootsResult,
)
async def send_ping(self) -> EmptyResult: async def send_ping(self) -> EmptyResult:
"""Send a ping request.""" """Send a ping request."""
from mcp_python.types import PingRequest from mcp_python.types import PingRequest
@@ -198,3 +217,33 @@ class ServerSession(
) )
) )
) )
async def send_resource_list_changed(self) -> None:
"""Send a resource list changed notification."""
await self.send_notification(
ServerNotification(
ResourceListChangedNotification(
method="notifications/resources/list_changed",
)
)
)
async def send_tool_list_changed(self) -> None:
"""Send a tool list changed notification."""
await self.send_notification(
ServerNotification(
ToolListChangedNotification(
method="notifications/tools/list_changed",
)
)
)
async def send_prompt_list_changed(self) -> None:
"""Send a prompt list changed notification."""
await self.send_notification(
ServerNotification(
PromptListChangedNotification(
method="notifications/prompts/list_changed",
)
)
)