Strict types on the client side (#285)

This commit is contained in:
Marcelo Trylesinski
2025-03-14 11:56:48 +01:00
committed by GitHub
parent 7196604468
commit 97201cce59
6 changed files with 23 additions and 14 deletions

View File

@@ -77,6 +77,7 @@ venvPath = "."
venv = ".venv"
strict = [
"src/mcp/server/fastmcp/tools/base.py",
"src/mcp/client/*.py"
]
[tool.ruff.lint]

View File

@@ -5,10 +5,12 @@ from functools import partial
from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import JSONRPCMessage
if not sys.warnoptions:
import warnings
@@ -29,7 +31,10 @@ async def receive_loop(session: ClientSession):
logger.info("Received message from server: %s", message)
async def run_session(read_stream, write_stream):
async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
):
async with (
ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg,

View File

@@ -76,19 +76,13 @@ class ClientSession(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
async def initialize(self) -> types.InitializeResult:
sampling = (
types.SamplingCapability() if self._sampling_callback is not None else None
)
roots = (
types.RootsCapability(
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
)
result = await self.send_request(
types.ClientRequest(

View File

@@ -98,6 +98,10 @@ async def sse_client(
continue
await read_stream_writer.send(message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)

View File

@@ -39,6 +39,11 @@ async def websocket_client(
# Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

View File

@@ -1,3 +1,3 @@
from mcp.types import LATEST_PROTOCOL_VERSION
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]
SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)