Add in-memory transport (#25)

## Goal

Support running an MCP server in the same process as the client, while preserving MCP abstractions.

## Details

1. **(core change)** Adds a new `memory` transport module that enables in-process client-server communication.
This includes:
   - `create_client_server_memory_streams()` to create bidirectional memory streams
   - `create_connected_server_and_client_session()` to establish an in-process client-server connection

3. (minor) Enhances error handling and timeout support:
   - Adds configurable read timeouts to sessions via `read_timeout_seconds` parameter
   - Improves exception handling in the server with a new `raise_exceptions` flag to control whether errors are returned to clients or raised directly
   - Ensures proper cleanup of request context tokens in error cases

4. (minor) Makes server improvements:
   - Adds built-in ping handler support
This commit is contained in:
Nick Merrill
2024-11-05 18:42:41 -05:00
committed by GitHub
parent 1a60e1b7c7
commit 60e9c7a0d7
9 changed files with 210 additions and 10 deletions

View File

@@ -1,3 +1,5 @@
from datetime import timedelta
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
@@ -36,8 +38,15 @@ class ClientSession(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
async def initialize(self) -> InitializeResult:
from mcp_python.types import (

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
CompleteRequest,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
@@ -27,6 +28,7 @@ from mcp_python.types import (
ListToolsRequest,
ListToolsResult,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptReference,
@@ -52,9 +54,11 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")
logger.debug(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
@@ -63,9 +67,13 @@ class Server:
try:
from importlib.metadata import version
return version(package)
v = version(package)
if v is not None:
return v
except Exception:
return "unknown"
pass
return "unknown"
return types.InitializationOptions(
server_name=self.name,
@@ -330,6 +338,11 @@ class Server:
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
@@ -349,6 +362,7 @@ class Server:
f"Dispatching request of type {type(req).__name__}"
)
token = None
try:
# Set our global state that can be retrieved via
# app.get_request_context()
@@ -360,12 +374,16 @@ class Server:
)
)
response = await handler(req)
# Reset the global state after we are done
request_ctx.reset(token)
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
code=0, message=str(err), data=None
)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)
await message.respond(response)
else:
@@ -399,3 +417,7 @@ class Server:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)
async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())

View File

@@ -0,0 +1,87 @@
"""
In-memory transports
"""
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.types import JSONRPCMessage
MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage]
]
@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[
tuple[MessageStream, MessageStream],
None
]:
"""
Creates a pair of bidirectional memory streams for client-server communication.
Returns:
A tuple of (client_streams, server_streams) where each is a tuple of
(read_stream, write_stream)
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)
async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
yield client_streams, server_streams
@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams
# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: server.run(
server_read,
server_write,
server.create_initialization_options(),
raise_exceptions=raise_exceptions,
)
)
try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
) as client_session:
await client_session.initialize()
yield client_session
finally:
tg.cancel_scope.cancel()

View File

@@ -1,8 +1,10 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
import anyio
import anyio.lowlevel
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
@@ -87,6 +89,8 @@ class BaseSession(
write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
@@ -94,6 +98,7 @@ class BaseSession(
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
@@ -147,7 +152,25 @@ class BaseSession(
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
response_or_error = await response_stream_reader.receive()
try:
with anyio.fail_after(
None if self._read_timeout_seconds is None
else self._read_timeout_seconds.total_seconds()
):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=httpx.codes.REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{self._read_timeout_seconds} seconds."
),
)
)
if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:

View File

@@ -141,16 +141,19 @@ class ErrorData(BaseModel):
code: int
"""The error type that occurred."""
message: str
"""
A short description of the error. The message SHOULD be limited to a concise single
sentence.
"""
data: Any | None = None
"""
Additional information about the error. The value of this member is defined by the
sender (e.g. detailed error information, nested errors etc.).
"""
model_config = ConfigDict(extra="allow")