Strict types on the client side (#285)

This commit is contained in:
Marcelo Trylesinski
2025-03-14 11:56:48 +01:00
committed by GitHub
parent 7196604468
commit 97201cce59
6 changed files with 23 additions and 14 deletions

View File

@@ -77,6 +77,7 @@ venvPath = "."
venv = ".venv" venv = ".venv"
strict = [ strict = [
"src/mcp/server/fastmcp/tools/base.py", "src/mcp/server/fastmcp/tools/base.py",
"src/mcp/client/*.py"
] ]
[tool.ruff.lint] [tool.ruff.lint]

View File

@@ -5,10 +5,12 @@ from functools import partial
from urllib.parse import urlparse from urllib.parse import urlparse
import anyio import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import JSONRPCMessage
if not sys.warnoptions: if not sys.warnoptions:
import warnings import warnings
@@ -29,7 +31,10 @@ async def receive_loop(session: ClientSession):
logger.info("Received message from server: %s", message) logger.info("Received message from server: %s", message)
async def run_session(read_stream, write_stream): async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
):
async with ( async with (
ClientSession(read_stream, write_stream) as session, ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg, anyio.create_task_group() as tg,

View File

@@ -76,18 +76,12 @@ class ClientSession(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback
async def initialize(self) -> types.InitializeResult: async def initialize(self) -> types.InitializeResult:
sampling = ( sampling = types.SamplingCapability()
types.SamplingCapability() if self._sampling_callback is not None else None roots = types.RootsCapability(
) # TODO: Should this be based on whether we
roots = ( # _will_ send notifications, or only whether
types.RootsCapability( # they're supported?
# TODO: Should this be based on whether we listChanged=True,
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
) )
result = await self.send_request( result = await self.send_request(

View File

@@ -98,6 +98,10 @@ async def sse_client(
continue continue
await read_stream_writer.send(message) await read_stream_writer.send(message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc: except Exception as exc:
logger.error(f"Error in sse_reader: {exc}") logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)

View File

@@ -39,6 +39,11 @@ async def websocket_client(
# Create two in-memory streams: # Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader) # - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer) # - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0) read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

View File

@@ -1,3 +1,3 @@
from mcp.types import LATEST_PROTOCOL_VERSION from mcp.types import LATEST_PROTOCOL_VERSION
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION] SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)