Add in-memory transport (#25)

## Goal

Support running an MCP server in the same process as the client, while preserving MCP abstractions.

## Details

1. **(core change)** Adds a new `memory` transport module that enables in-process client-server communication.
This includes:
   - `create_client_server_memory_streams()` to create bidirectional memory streams
   - `create_connected_server_and_client_session()` to establish an in-process client-server connection

3. (minor) Enhances error handling and timeout support:
   - Adds configurable read timeouts to sessions via `read_timeout_seconds` parameter
   - Improves exception handling in the server with a new `raise_exceptions` flag to control whether errors are returned to clients or raised directly
   - Ensures proper cleanup of request context tokens in error cases

4. (minor) Makes server improvements:
   - Adds built-in ping handler support
This commit is contained in:
Nick Merrill
2024-11-05 18:42:41 -05:00
committed by GitHub
parent 1a60e1b7c7
commit 60e9c7a0d7
9 changed files with 210 additions and 10 deletions

View File

@@ -1,3 +1,5 @@
from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
@@ -36,8 +38,15 @@ class ClientSession(
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None: ) -> None:
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification) super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
async def initialize(self) -> InitializeResult: async def initialize(self) -> InitializeResult:
from mcp_python.types import ( from mcp_python.types import (

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CompleteRequest, CompleteRequest,
EmptyResult,
ErrorData, ErrorData,
JSONRPCMessage, JSONRPCMessage,
ListPromptsRequest, ListPromptsRequest,
@@ -27,6 +28,7 @@ from mcp_python.types import (
ListToolsRequest, ListToolsRequest,
ListToolsResult, ListToolsResult,
LoggingLevel, LoggingLevel,
PingRequest,
ProgressNotification, ProgressNotification,
Prompt, Prompt,
PromptReference, PromptReference,
@@ -52,9 +54,11 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
class Server: class Server:
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {} self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'") logger.debug(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions: def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance.""" """Create initialization options from this server instance."""
@@ -63,9 +67,13 @@ class Server:
try: try:
from importlib.metadata import version from importlib.metadata import version
return version(package) v = version(package)
if v is not None:
return v
except Exception: except Exception:
return "unknown" pass
return "unknown"
return types.InitializationOptions( return types.InitializationOptions(
server_name=self.name, server_name=self.name,
@@ -330,6 +338,11 @@ class Server:
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions, initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
): ):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
async with ServerSession( async with ServerSession(
@@ -349,6 +362,7 @@ class Server:
f"Dispatching request of type {type(req).__name__}" f"Dispatching request of type {type(req).__name__}"
) )
token = None
try: try:
# Set our global state that can be retrieved via # Set our global state that can be retrieved via
# app.get_request_context() # app.get_request_context()
@@ -360,12 +374,16 @@ class Server:
) )
) )
response = await handler(req) response = await handler(req)
# Reset the global state after we are done
request_ctx.reset(token)
except Exception as err: except Exception as err:
if raise_exceptions:
raise err
response = ErrorData( response = ErrorData(
code=0, message=str(err), data=None code=0, message=str(err), data=None
) )
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)
await message.respond(response) await message.respond(response)
else: else:
@@ -399,3 +417,7 @@ class Server:
logger.info( logger.info(
f"Warning: {warning.category.__name__}: {warning.message}" f"Warning: {warning.category.__name__}: {warning.message}"
) )
async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())

View File

@@ -0,0 +1,87 @@
"""
In-memory transports
"""
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage]
]
@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[
tuple[MessageStream, MessageStream],
None
]:
"""
Creates a pair of bidirectional memory streams for client-server communication.
Returns:
A tuple of (client_streams, server_streams) where each is a tuple of
(read_stream, write_stream)
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)
async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
yield client_streams, server_streams
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: server.run(
server_read,
server_write,
server.create_initialization_options(),
raise_exceptions=raise_exceptions,
)
)
try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
) as client_session:
await client_session.initialize()
yield client_session
finally:
tg.cancel_scope.cancel()

View File

@@ -1,8 +1,10 @@
from contextlib import AbstractAsyncContextManager from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar from typing import Generic, TypeVar
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel from pydantic import BaseModel
@@ -87,6 +89,8 @@ class BaseSession(
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT], receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT], receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None: ) -> None:
self._read_stream = read_stream self._read_stream = read_stream
self._write_stream = write_stream self._write_stream = write_stream
@@ -94,6 +98,7 @@ class BaseSession(
self._request_id = 0 self._request_id = 0
self._receive_request_type = receive_request_type self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ anyio.create_memory_object_stream[
@@ -147,7 +152,25 @@ class BaseSession(
await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
response_or_error = await response_stream_reader.receive() try:
with anyio.fail_after(
None if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{self._read_timeout_seconds} seconds."
),
)
)
if isinstance(response_or_error, JSONRPCError): if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error) raise McpError(response_or_error.error)
else: else:

View File

@@ -141,16 +141,19 @@ class ErrorData(BaseModel):
code: int code: int
"""The error type that occurred.""" """The error type that occurred."""
message: str message: str
""" """
A short description of the error. The message SHOULD be limited to a concise single A short description of the error. The message SHOULD be limited to a concise single
sentence. sentence.
""" """
data: Any | None = None data: Any | None = None
""" """
Additional information about the error. The value of this member is defined by the Additional information about the error. The value of this member is defined by the
sender (e.g. detailed error information, nested errors etc.). sender (e.g. detailed error information, nested errors etc.).
""" """
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "mcp-python" name = "mcp-python"
version = "0.4.0.dev" version = "0.5.0dev"
description = "Model Context Protocol implementation for Python" description = "Model Context Protocol implementation for Python"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"

28
tests/conftest.py Normal file
View File

@@ -0,0 +1,28 @@
import pytest
from pydantic import AnyUrl
from mcp_python.server import Server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import Resource, ServerCapabilities
TEST_INITIALIZATION_OPTIONS = InitializationOptions(
server_name="my_mcp_server",
server_version="0.1.0",
capabilities=ServerCapabilities(),
)
@pytest.fixture
def mcp_server() -> Server:
server = Server(name="test_server")
@server.list_resources()
async def handle_list_resources():
return [
Resource(
uri=AnyUrl("memory://test"),
name="Test Resource",
description="A test resource"
)
]
return server

View File

@@ -0,0 +1,28 @@
import pytest
from typing_extensions import AsyncGenerator
from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.shared.memory import (
create_connected_server_and_client_session,
)
from mcp_python.types import (
EmptyResult,
)
@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session
@pytest.mark.anyio
async def test_memory_server_and_client_connection(
client_connected_to_server: ClientSession,
):
"""Shows how a client and server can communicate over memory streams."""
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)

2
uv.lock generated
View File

@@ -163,7 +163,7 @@ wheels = [
[[package]] [[package]]
name = "mcp-python" name = "mcp-python"
version = "0.3.0.dev0" version = "0.5.0.dev0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },