This commit is contained in:
David Soria Parra
2024-11-11 20:17:39 +00:00
parent b9b44e6dad
commit ec8c85edea
8 changed files with 37 additions and 21 deletions

View File

@@ -3,9 +3,9 @@ from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
import mcp.types as types
class ClientSession(

View File

@@ -84,10 +84,8 @@ async def sse_client(
case "message":
try:
message = (
types.JSONRPCMessage.model_validate_json(
sse.data
)
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"

View File

@@ -7,12 +7,12 @@ from typing import Any, Sequence
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
import mcp.types as types
logger = logging.getLogger(__name__)
@@ -36,7 +36,9 @@ class NotificationOptions:
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
types.PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
@@ -153,7 +155,9 @@ class Server:
async def handler(_: Any):
resources = await func()
return types.ServerResult(types.ListResourcesResult(resources=resources))
return types.ServerResult(
types.ListResourcesResult(resources=resources)
)
self.request_handlers[types.ListResourcesRequest] = handler
return func
@@ -249,7 +253,11 @@ class Server:
def decorator(
func: Callable[
...,
Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]],
Awaitable[
Sequence[
types.TextContent | types.ImageContent | types.EmbeddedResource
]
],
],
):
logger.debug("Registering handler for CallToolRequest")
@@ -261,7 +269,9 @@ class Server:
for result in results:
match result:
case str() as text:
content.append(types.TextContent(type="text", text=text))
content.append(
types.TextContent(type="text", text=text)
)
case types.ImageContent() as img:
content.append(
types.ImageContent(
@@ -277,7 +287,9 @@ class Server:
)
)
return types.ServerResult(types.CallToolResult(content=content, isError=False))
return types.ServerResult(
types.CallToolResult(content=content, isError=False)
)
except Exception as e:
return types.ServerResult(
types.CallToolResult(
@@ -312,7 +324,10 @@ class Server:
def decorator(
func: Callable[
[types.PromptReference | types.ResourceReference, types.CompletionArgument],
[
types.PromptReference | types.ResourceReference,
types.CompletionArgument,
],
Awaitable[types.Completion | None],
],
):

View File

@@ -4,9 +4,9 @@ import sys
import anyio
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.server.models import InitializationOptions
from mcp.types import ServerCapabilities
if not sys.warnoptions:

View File

@@ -3,9 +3,6 @@ This module provides simpler types to use with the server for managing prompts
and tools.
"""
from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel
from mcp.types import (

View File

@@ -6,12 +6,12 @@ import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.session import (
BaseSession,
RequestResponder,
)
import mcp.types as types
class InitializationState(Enum):
@@ -37,7 +37,9 @@ class ServerSession(
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
init_options: InitializationOptions,
) -> None:
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
super().__init__(
read_stream, write_stream, types.ClientRequest, types.ClientNotification
)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
@@ -65,7 +67,9 @@ class ServerSession(
"Received request before initialization was complete"
)
async def _received_notification(self, notification: types.ClientNotification) -> None:
async def _received_notification(
self, notification: types.ClientNotification
) -> None:
# Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint()
match notification.root:

View File

@@ -30,7 +30,9 @@ class SseServerTransport:
"""
_endpoint: str
_read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]]
_read_stream_writers: dict[
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
]
def __init__(self, endpoint: str) -> None:
"""

View File

@@ -3,8 +3,8 @@ import pytest
from mcp.client.session import ClientSession
from mcp.server import NotificationOptions, Server
from mcp.server.session import ServerSession
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.types import (
ClientNotification,
InitializedNotification,