mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Add in-memory transport (#25)
## Goal Support running an MCP server in the same process as the client, while preserving MCP abstractions. ## Details 1. **(core change)** Adds a new `memory` transport module that enables in-process client-server communication. This includes: - `create_client_server_memory_streams()` to create bidirectional memory streams - `create_connected_server_and_client_session()` to establish an in-process client-server connection 3. (minor) Enhances error handling and timeout support: - Adds configurable read timeouts to sessions via `read_timeout_seconds` parameter - Improves exception handling in the server with a new `raise_exceptions` flag to control whether errors are returned to clients or raised directly - Ensures proper cleanup of request context tokens in error cases 4. (minor) Makes server improvements: - Adds built-in ping handler support
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
@@ -36,8 +38,15 @@ class ClientSession(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
async def initialize(self) -> InitializeResult:
|
||||
from mcp_python.types import (
|
||||
|
||||
@@ -18,6 +18,7 @@ from mcp_python.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CompleteRequest,
|
||||
EmptyResult,
|
||||
ErrorData,
|
||||
JSONRPCMessage,
|
||||
ListPromptsRequest,
|
||||
@@ -27,6 +28,7 @@ from mcp_python.types import (
|
||||
ListToolsRequest,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
Prompt,
|
||||
PromptReference,
|
||||
@@ -52,9 +54,11 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
|
||||
PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
logger.info(f"Initializing server '{name}'")
|
||||
logger.debug(f"Initializing server '{name}'")
|
||||
|
||||
def create_initialization_options(self) -> types.InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
@@ -63,9 +67,13 @@ class Server:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
return version(package)
|
||||
v = version(package)
|
||||
if v is not None:
|
||||
return v
|
||||
except Exception:
|
||||
return "unknown"
|
||||
pass
|
||||
|
||||
return "unknown"
|
||||
|
||||
return types.InitializationOptions(
|
||||
server_name=self.name,
|
||||
@@ -330,6 +338,11 @@ class Server:
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
initialization_options: types.InitializationOptions,
|
||||
# When True, exceptions are returned as messages to the client.
|
||||
# When False, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
@@ -349,6 +362,7 @@ class Server:
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
@@ -360,12 +374,16 @@ class Server:
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
# Reset the global state after we are done
|
||||
request_ctx.reset(token)
|
||||
except Exception as err:
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None:
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
@@ -399,3 +417,7 @@ class Server:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
|
||||
|
||||
async def _ping_handler(request: PingRequest) -> ServerResult:
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
87
mcp_python/shared/memory.py
Normal file
87
mcp_python/shared/memory.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
In-memory transports
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.server import Server
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
MessageStream = tuple[
|
||||
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
MemoryObjectSendStream[JSONRPCMessage]
|
||||
]
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_client_server_memory_streams() -> AsyncGenerator[
|
||||
tuple[MessageStream, MessageStream],
|
||||
None
|
||||
]:
|
||||
"""
|
||||
Creates a pair of bidirectional memory streams for client-server communication.
|
||||
|
||||
Returns:
|
||||
A tuple of (client_streams, server_streams) where each is a tuple of
|
||||
(read_stream, write_stream)
|
||||
"""
|
||||
# Create streams for both directions
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
JSONRPCMessage | Exception
|
||||
](1)
|
||||
|
||||
client_streams = (server_to_client_receive, client_to_server_send)
|
||||
server_streams = (client_to_server_receive, server_to_client_send)
|
||||
|
||||
async with (
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
):
|
||||
yield client_streams, server_streams
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_connected_server_and_client_session(
|
||||
server: Server,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
async with create_client_server_memory_streams() as (
|
||||
client_streams,
|
||||
server_streams,
|
||||
):
|
||||
client_read, client_write = client_streams
|
||||
server_read, server_write = server_streams
|
||||
|
||||
# Create a cancel scope for the server task
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(
|
||||
lambda: server.run(
|
||||
server_read,
|
||||
server_write,
|
||||
server.create_initialization_options(),
|
||||
raise_exceptions=raise_exceptions,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async with ClientSession(
|
||||
read_stream=client_read,
|
||||
write_stream=client_write,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
@@ -1,8 +1,10 @@
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
import httpx
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -87,6 +89,8 @@ class BaseSession(
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
@@ -94,6 +98,7 @@ class BaseSession(
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._read_timeout_seconds = read_timeout_seconds
|
||||
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[
|
||||
@@ -147,7 +152,25 @@ class BaseSession(
|
||||
|
||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
try:
|
||||
with anyio.fail_after(
|
||||
None if self._read_timeout_seconds is None
|
||||
else self._read_timeout_seconds.total_seconds()
|
||||
):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{self._read_timeout_seconds} seconds."
|
||||
),
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
|
||||
@@ -141,16 +141,19 @@ class ErrorData(BaseModel):
|
||||
|
||||
code: int
|
||||
"""The error type that occurred."""
|
||||
|
||||
message: str
|
||||
"""
|
||||
A short description of the error. The message SHOULD be limited to a concise single
|
||||
sentence.
|
||||
"""
|
||||
|
||||
data: Any | None = None
|
||||
"""
|
||||
Additional information about the error. The value of this member is defined by the
|
||||
sender (e.g. detailed error information, nested errors etc.).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "mcp-python"
|
||||
version = "0.4.0.dev"
|
||||
version = "0.5.0dev"
|
||||
description = "Model Context Protocol implementation for Python"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
28
tests/conftest.py
Normal file
28
tests/conftest.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.server import Server
|
||||
from mcp_python.server.types import InitializationOptions
|
||||
from mcp_python.types import Resource, ServerCapabilities
|
||||
|
||||
TEST_INITIALIZATION_OPTIONS = InitializationOptions(
|
||||
server_name="my_mcp_server",
|
||||
server_version="0.1.0",
|
||||
capabilities=ServerCapabilities(),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> Server:
|
||||
server = Server(name="test_server")
|
||||
|
||||
@server.list_resources()
|
||||
async def handle_list_resources():
|
||||
return [
|
||||
Resource(
|
||||
uri=AnyUrl("memory://test"),
|
||||
name="Test Resource",
|
||||
description="A test resource"
|
||||
)
|
||||
]
|
||||
|
||||
return server
|
||||
28
tests/shared/test_memory.py
Normal file
28
tests/shared/test_memory.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
from typing_extensions import AsyncGenerator
|
||||
|
||||
from mcp_python.client.session import ClientSession
|
||||
from mcp_python.server import Server
|
||||
from mcp_python.shared.memory import (
|
||||
create_connected_server_and_client_session,
|
||||
)
|
||||
from mcp_python.types import (
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client_connected_to_server(
|
||||
mcp_server: Server,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
async with create_connected_server_and_client_session(mcp_server) as client_session:
|
||||
yield client_session
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_memory_server_and_client_connection(
|
||||
client_connected_to_server: ClientSession,
|
||||
):
|
||||
"""Shows how a client and server can communicate over memory streams."""
|
||||
response = await client_connected_to_server.send_ping()
|
||||
assert isinstance(response, EmptyResult)
|
||||
Reference in New Issue
Block a user