mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 23:04:25 +01:00
Add in-memory transport (#25)
## Goal

Support running an MCP server in the same process as the client, while preserving MCP abstractions.

## Details

1. **(core change)** Adds a new `memory` transport module that enables in-process client-server communication. This includes:
   - `create_client_server_memory_streams()` to create bidirectional memory streams
   - `create_connected_server_and_client_session()` to establish an in-process client-server connection
2. (minor) Enhances error handling and timeout support:
   - Adds configurable read timeouts to sessions via the `read_timeout_seconds` parameter
   - Improves exception handling in the server with a new `raise_exceptions` flag to control whether errors are returned to clients or raised directly
   - Ensures proper cleanup of request context tokens in error cases
3. (minor) Makes server improvements:
   - Adds built-in ping handler support
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
@@ -36,8 +38,15 @@ class ClientSession(
|
|||||||
self,
|
self,
|
||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||||
|
read_timeout_seconds: timedelta | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
|
super().__init__(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
ServerRequest,
|
||||||
|
ServerNotification,
|
||||||
|
read_timeout_seconds=read_timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
async def initialize(self) -> InitializeResult:
|
async def initialize(self) -> InitializeResult:
|
||||||
from mcp_python.types import (
|
from mcp_python.types import (
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from mcp_python.types import (
|
|||||||
ClientNotification,
|
ClientNotification,
|
||||||
ClientRequest,
|
ClientRequest,
|
||||||
CompleteRequest,
|
CompleteRequest,
|
||||||
|
EmptyResult,
|
||||||
ErrorData,
|
ErrorData,
|
||||||
JSONRPCMessage,
|
JSONRPCMessage,
|
||||||
ListPromptsRequest,
|
ListPromptsRequest,
|
||||||
@@ -27,6 +28,7 @@ from mcp_python.types import (
|
|||||||
ListToolsRequest,
|
ListToolsRequest,
|
||||||
ListToolsResult,
|
ListToolsResult,
|
||||||
LoggingLevel,
|
LoggingLevel,
|
||||||
|
PingRequest,
|
||||||
ProgressNotification,
|
ProgressNotification,
|
||||||
Prompt,
|
Prompt,
|
||||||
PromptReference,
|
PromptReference,
|
||||||
@@ -52,9 +54,11 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
|||||||
class Server:
|
class Server:
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
|
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
|
||||||
|
PingRequest: _ping_handler,
|
||||||
|
}
|
||||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||||
logger.info(f"Initializing server '{name}'")
|
logger.debug(f"Initializing server '{name}'")
|
||||||
|
|
||||||
def create_initialization_options(self) -> types.InitializationOptions:
|
def create_initialization_options(self) -> types.InitializationOptions:
|
||||||
"""Create initialization options from this server instance."""
|
"""Create initialization options from this server instance."""
|
||||||
@@ -63,9 +67,13 @@ class Server:
|
|||||||
try:
|
try:
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
return version(package)
|
v = version(package)
|
||||||
|
if v is not None:
|
||||||
|
return v
|
||||||
except Exception:
|
except Exception:
|
||||||
return "unknown"
|
pass
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
return types.InitializationOptions(
|
return types.InitializationOptions(
|
||||||
server_name=self.name,
|
server_name=self.name,
|
||||||
@@ -330,6 +338,11 @@ class Server:
|
|||||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||||
initialization_options: types.InitializationOptions,
|
initialization_options: types.InitializationOptions,
|
||||||
|
# When False, exceptions are returned as messages to the client.
|
||||||
|
# When True, exceptions are raised, which will cause the server to shut down
|
||||||
|
# but also make tracing exceptions much easier during testing and when using
|
||||||
|
# in-process servers.
|
||||||
|
raise_exceptions: bool = False,
|
||||||
):
|
):
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
async with ServerSession(
|
async with ServerSession(
|
||||||
@@ -349,6 +362,7 @@ class Server:
|
|||||||
f"Dispatching request of type {type(req).__name__}"
|
f"Dispatching request of type {type(req).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token = None
|
||||||
try:
|
try:
|
||||||
# Set our global state that can be retrieved via
|
# Set our global state that can be retrieved via
|
||||||
# app.get_request_context()
|
# app.get_request_context()
|
||||||
@@ -360,12 +374,16 @@ class Server:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
response = await handler(req)
|
response = await handler(req)
|
||||||
# Reset the global state after we are done
|
|
||||||
request_ctx.reset(token)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
if raise_exceptions:
|
||||||
|
raise err
|
||||||
response = ErrorData(
|
response = ErrorData(
|
||||||
code=0, message=str(err), data=None
|
code=0, message=str(err), data=None
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
# Reset the global state after we are done
|
||||||
|
if token is not None:
|
||||||
|
request_ctx.reset(token)
|
||||||
|
|
||||||
await message.respond(response)
|
await message.respond(response)
|
||||||
else:
|
else:
|
||||||
@@ -399,3 +417,7 @@ class Server:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ping_handler(request: PingRequest) -> ServerResult:
|
||||||
|
return ServerResult(EmptyResult())
|
||||||
|
|||||||
87
mcp_python/shared/memory.py
Normal file
87
mcp_python/shared/memory.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
In-memory transports
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
|
||||||
|
from mcp_python.client.session import ClientSession
|
||||||
|
from mcp_python.server import Server
|
||||||
|
from mcp_python.types import JSONRPCMessage
|
||||||
|
|
||||||
|
MessageStream = tuple[
|
||||||
|
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||||
|
MemoryObjectSendStream[JSONRPCMessage]
|
||||||
|
]
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_client_server_memory_streams() -> AsyncGenerator[
|
||||||
|
tuple[MessageStream, MessageStream],
|
||||||
|
None
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Creates a pair of bidirectional memory streams for client-server communication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (client_streams, server_streams) where each is a tuple of
|
||||||
|
(read_stream, write_stream)
|
||||||
|
"""
|
||||||
|
# Create streams for both directions
|
||||||
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||||
|
JSONRPCMessage | Exception
|
||||||
|
](1)
|
||||||
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||||
|
JSONRPCMessage | Exception
|
||||||
|
](1)
|
||||||
|
|
||||||
|
client_streams = (server_to_client_receive, client_to_server_send)
|
||||||
|
server_streams = (client_to_server_receive, server_to_client_send)
|
||||||
|
|
||||||
|
async with (
|
||||||
|
server_to_client_receive,
|
||||||
|
client_to_server_send,
|
||||||
|
client_to_server_receive,
|
||||||
|
server_to_client_send,
|
||||||
|
):
|
||||||
|
yield client_streams, server_streams
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_connected_server_and_client_session(
|
||||||
|
server: Server,
|
||||||
|
read_timeout_seconds: timedelta | None = None,
|
||||||
|
raise_exceptions: bool = False,
|
||||||
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
|
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||||
|
async with create_client_server_memory_streams() as (
|
||||||
|
client_streams,
|
||||||
|
server_streams,
|
||||||
|
):
|
||||||
|
client_read, client_write = client_streams
|
||||||
|
server_read, server_write = server_streams
|
||||||
|
|
||||||
|
# Run the server in a task group so its cancel scope can stop it on exit
|
||||||
|
async with anyio.create_task_group() as tg:
|
||||||
|
tg.start_soon(
|
||||||
|
lambda: server.run(
|
||||||
|
server_read,
|
||||||
|
server_write,
|
||||||
|
server.create_initialization_options(),
|
||||||
|
raise_exceptions=raise_exceptions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with ClientSession(
|
||||||
|
read_stream=client_read,
|
||||||
|
write_stream=client_write,
|
||||||
|
read_timeout_seconds=read_timeout_seconds,
|
||||||
|
) as client_session:
|
||||||
|
await client_session.initialize()
|
||||||
|
yield client_session
|
||||||
|
finally:
|
||||||
|
tg.cancel_scope.cancel()
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
from contextlib import AbstractAsyncContextManager
|
from contextlib import AbstractAsyncContextManager
|
||||||
|
from datetime import timedelta
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import anyio.lowlevel
|
import anyio.lowlevel
|
||||||
|
import httpx
|
||||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -87,6 +89,8 @@ class BaseSession(
|
|||||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||||
receive_request_type: type[ReceiveRequestT],
|
receive_request_type: type[ReceiveRequestT],
|
||||||
receive_notification_type: type[ReceiveNotificationT],
|
receive_notification_type: type[ReceiveNotificationT],
|
||||||
|
# If None, reading will never time out
|
||||||
|
read_timeout_seconds: timedelta | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._read_stream = read_stream
|
self._read_stream = read_stream
|
||||||
self._write_stream = write_stream
|
self._write_stream = write_stream
|
||||||
@@ -94,6 +98,7 @@ class BaseSession(
|
|||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._receive_request_type = receive_request_type
|
self._receive_request_type = receive_request_type
|
||||||
self._receive_notification_type = receive_notification_type
|
self._receive_notification_type = receive_notification_type
|
||||||
|
self._read_timeout_seconds = read_timeout_seconds
|
||||||
|
|
||||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||||
anyio.create_memory_object_stream[
|
anyio.create_memory_object_stream[
|
||||||
@@ -147,7 +152,25 @@ class BaseSession(
|
|||||||
|
|
||||||
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
|
||||||
|
|
||||||
response_or_error = await response_stream_reader.receive()
|
try:
|
||||||
|
with anyio.fail_after(
|
||||||
|
None if self._read_timeout_seconds is None
|
||||||
|
else self._read_timeout_seconds.total_seconds()
|
||||||
|
):
|
||||||
|
response_or_error = await response_stream_reader.receive()
|
||||||
|
except TimeoutError:
|
||||||
|
raise McpError(
|
||||||
|
ErrorData(
|
||||||
|
code=httpx.codes.REQUEST_TIMEOUT,
|
||||||
|
message=(
|
||||||
|
f"Timed out while waiting for response to "
|
||||||
|
f"{request.__class__.__name__}. Waited "
|
||||||
|
f"{self._read_timeout_seconds} seconds."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(response_or_error, JSONRPCError):
|
if isinstance(response_or_error, JSONRPCError):
|
||||||
raise McpError(response_or_error.error)
|
raise McpError(response_or_error.error)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -141,16 +141,19 @@ class ErrorData(BaseModel):
|
|||||||
|
|
||||||
code: int
|
code: int
|
||||||
"""The error type that occurred."""
|
"""The error type that occurred."""
|
||||||
|
|
||||||
message: str
|
message: str
|
||||||
"""
|
"""
|
||||||
A short description of the error. The message SHOULD be limited to a concise single
|
A short description of the error. The message SHOULD be limited to a concise single
|
||||||
sentence.
|
sentence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: Any | None = None
|
data: Any | None = None
|
||||||
"""
|
"""
|
||||||
Additional information about the error. The value of this member is defined by the
|
Additional information about the error. The value of this member is defined by the
|
||||||
sender (e.g. detailed error information, nested errors etc.).
|
sender (e.g. detailed error information, nested errors etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "mcp-python"
|
name = "mcp-python"
|
||||||
version = "0.4.0.dev"
|
version = "0.5.0dev"
|
||||||
description = "Model Context Protocol implementation for Python"
|
description = "Model Context Protocol implementation for Python"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
28
tests/conftest.py
Normal file
28
tests/conftest.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import pytest
|
||||||
|
from pydantic import AnyUrl
|
||||||
|
|
||||||
|
from mcp_python.server import Server
|
||||||
|
from mcp_python.server.types import InitializationOptions
|
||||||
|
from mcp_python.types import Resource, ServerCapabilities
|
||||||
|
|
||||||
|
TEST_INITIALIZATION_OPTIONS = InitializationOptions(
|
||||||
|
server_name="my_mcp_server",
|
||||||
|
server_version="0.1.0",
|
||||||
|
capabilities=ServerCapabilities(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mcp_server() -> Server:
|
||||||
|
server = Server(name="test_server")
|
||||||
|
|
||||||
|
@server.list_resources()
|
||||||
|
async def handle_list_resources():
|
||||||
|
return [
|
||||||
|
Resource(
|
||||||
|
uri=AnyUrl("memory://test"),
|
||||||
|
name="Test Resource",
|
||||||
|
description="A test resource"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return server
|
||||||
28
tests/shared/test_memory.py
Normal file
28
tests/shared/test_memory.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import pytest
|
||||||
|
from typing_extensions import AsyncGenerator
|
||||||
|
|
||||||
|
from mcp_python.client.session import ClientSession
|
||||||
|
from mcp_python.server import Server
|
||||||
|
from mcp_python.shared.memory import (
|
||||||
|
create_connected_server_and_client_session,
|
||||||
|
)
|
||||||
|
from mcp_python.types import (
|
||||||
|
EmptyResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client_connected_to_server(
|
||||||
|
mcp_server: Server,
|
||||||
|
) -> AsyncGenerator[ClientSession, None]:
|
||||||
|
async with create_connected_server_and_client_session(mcp_server) as client_session:
|
||||||
|
yield client_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_memory_server_and_client_connection(
|
||||||
|
client_connected_to_server: ClientSession,
|
||||||
|
):
|
||||||
|
"""Shows how a client and server can communicate over memory streams."""
|
||||||
|
response = await client_connected_to_server.send_ping()
|
||||||
|
assert isinstance(response, EmptyResult)
|
||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -163,7 +163,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mcp-python"
|
name = "mcp-python"
|
||||||
version = "0.3.0.dev0"
|
version = "0.5.0.dev0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "anyio" },
|
{ name = "anyio" },
|
||||||
|
|||||||
Reference in New Issue
Block a user