Format with ruff

This commit is contained in:
David Soria Parra
2024-10-11 11:54:16 +01:00
parent 9475815241
commit fd68df6687
15 changed files with 268 additions and 101 deletions

View File

@@ -55,9 +55,11 @@ class Server:
def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
return version(package)
except Exception:
return "unknown"
@@ -69,16 +71,17 @@ class Server:
)
def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None
"""Convert existing handlers to a ServerCapabilities object."""
return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None
return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest),
)
@property
def request_context(self) -> RequestContext:
@@ -87,7 +90,7 @@ class Server:
def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
logger.debug(f"Registering handler for PromptListRequest")
logger.debug("Registering handler for PromptListRequest")
async def handler(_: Any):
prompts = await func()
@@ -103,17 +106,19 @@ class Server:
GetPromptRequest,
GetPromptResult,
ImageContent,
Role as Role,
SamplingMessage,
TextContent,
)
from mcp_python.types import (
Role as Role,
)
def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
],
):
logger.debug(f"Registering handler for GetPromptRequest")
logger.debug("Registering handler for GetPromptRequest")
async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
@@ -149,7 +154,7 @@ class Server:
def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
logger.debug(f"Registering handler for ListResourcesRequest")
logger.debug("Registering handler for ListResourcesRequest")
async def handler(_: Any):
resources = await func()
@@ -169,7 +174,7 @@ class Server:
)
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug(f"Registering handler for ReadResourceRequest")
logger.debug("Registering handler for ReadResourceRequest")
async def handler(req: ReadResourceRequest):
result = await func(req.params.uri)
@@ -204,7 +209,7 @@ class Server:
from mcp_python.types import EmptyResult
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
logger.debug(f"Registering handler for SetLevelRequest")
logger.debug("Registering handler for SetLevelRequest")
async def handler(req: SetLevelRequest):
await func(req.params.level)
@@ -219,7 +224,7 @@ class Server:
from mcp_python.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for SubscribeRequest")
logger.debug("Registering handler for SubscribeRequest")
async def handler(req: SubscribeRequest):
await func(req.params.uri)
@@ -234,7 +239,7 @@ class Server:
from mcp_python.types import EmptyResult
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for UnsubscribeRequest")
logger.debug("Registering handler for UnsubscribeRequest")
async def handler(req: UnsubscribeRequest):
await func(req.params.uri)
@@ -249,7 +254,7 @@ class Server:
from mcp_python.types import CallToolResult
def decorator(func: Callable[..., Awaitable[Any]]):
logger.debug(f"Registering handler for CallToolRequest")
logger.debug("Registering handler for CallToolRequest")
async def handler(req: CallToolRequest):
result = await func(req.params.name, **(req.params.arguments or {}))
@@ -264,7 +269,7 @@ class Server:
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
):
logger.debug(f"Registering handler for ProgressNotification")
logger.debug("Registering handler for ProgressNotification")
async def handler(req: ProgressNotification):
await func(
@@ -286,7 +291,7 @@ class Server:
Awaitable[Completion | None],
],
):
logger.debug(f"Registering handler for CompleteRequest")
logger.debug("Registering handler for CompleteRequest")
async def handler(req: CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
@@ -307,10 +312,12 @@ class Server:
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
initialization_options: types.InitializationOptions,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream, initialization_options) as session:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")
@@ -359,14 +366,16 @@ class Server:
handler = self.notification_handlers[type(notify)]
logger.debug(
f"Dispatching notification of type {type(notify).__name__}"
f"Dispatching notification of type "
f"{type(notify).__name__}"
)
try:
await handler(notify)
except Exception as err:
logger.error(
f"Uncaught exception in notification handler: {err}"
f"Uncaught exception in notification handler: "
f"{err}"
)
for warning in w:

View File

@@ -1,11 +1,12 @@
import importlib.metadata
import logging
import sys
import importlib.metadata
import anyio
from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import ServerCapabilities
if not sys.warnoptions:
@@ -30,7 +31,18 @@ async def receive_loop(session: ServerSession):
async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
async with (
ServerSession(
read_stream,
write_stream,
InitializationOptions(
server_name="mcp_python",
server_version=version,
capabilities=ServerCapabilities(),
),
) as session,
write_stream,
):
await receive_loop(session)

View File

@@ -6,11 +6,11 @@ import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from mcp_python.server.types import InitializationOptions
from mcp_python.shared.session import (
BaseSession,
RequestResponder,
)
from mcp_python.server.types import InitializationOptions
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import (
ClientNotification,
@@ -25,7 +25,6 @@ from mcp_python.types import (
JSONRPCMessage,
LoggingLevel,
SamplingMessage,
ServerCapabilities,
ServerNotification,
ServerRequest,
ServerResult,
@@ -53,7 +52,7 @@ class ServerSession(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
init_options: InitializationOptions
init_options: InitializationOptions,
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
self._initialization_state = InitializationState.NotInitialized
@@ -72,7 +71,7 @@ class ServerSession(
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version
version=self._init_options.server_version,
),
)
)

View File

@@ -19,10 +19,14 @@ logger = logging.getLogger(__name__)
class SseServerTransport:
"""
SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn:
SSE server transport for MCP. This class provides _two_ ASGI applications,
suitable to be used with a framework like Starlette and a server like Hypercorn:
1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client.
2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session.
1. connect_sse() is an ASGI application which receives incoming GET requests,
and sets up a new SSE stream to send server messages to the client.
2. handle_post_message() is an ASGI application which receives incoming POST
requests, which should contain client messages that link to a
previously-established SSE session.
"""
_endpoint: str
@@ -30,7 +34,8 @@ class SseServerTransport:
def __init__(self, endpoint: str) -> None:
"""
Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given.
Creates a new SSE server transport, which will direct the client to POST
messages to the relative or absolute URL given.
"""
super().__init__()
@@ -74,7 +79,9 @@ class SseServerTransport:
await sse_stream_writer.send(
{
"event": "message",
"data": message.model_dump_json(by_alias=True, exclude_none=True),
"data": message.model_dump_json(
by_alias=True, exclude_none=True
),
}
)

View File

@@ -7,14 +7,18 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
from mcp_python.types import JSONRPCMessage
@asynccontextmanager
async def stdio_server(
stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None
stdin: anyio.AsyncFile[str] | None = None,
stdout: anyio.AsyncFile[str] | None = None,
):
"""
Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout.
Server transport for stdio: this communicates with an MCP client by reading
from the current process' stdin and writing to stdout.
"""
# Purposely not using context managers for these, as we don't want to close standard process handles.
# Purposely not using context managers for these, as we don't want to close
# standard process handles.
if not stdin:
stdin = anyio.wrap_file(sys.stdin)
if not stdout:

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities

View File

@@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_server(scope: Scope, receive: Receive, send: Send):
"""
WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn.
WebSocket server transport for MCP. This is an ASGI application, suitable to be
used with a framework like Starlette and a server like Hypercorn.
"""
websocket = WebSocket(scope, receive, send)
@@ -47,7 +48,9 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
try:
async with write_stream_reader:
async for message in write_stream_reader:
obj = message.model_dump(by_alias=True, mode="json", exclude_none=True)
obj = message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await websocket.send_json(obj)
except anyio.ClosedResourceError:
await websocket.close()