Add in-memory transport (#25)

## Goal

Support running an MCP server in the same process as the client, while preserving MCP abstractions.

## Details

1. **(core change)** Adds a new `memory` transport module that enables in-process client-server communication.
This includes:
   - `create_client_server_memory_streams()` to create bidirectional memory streams
   - `create_connected_server_and_client_session()` to establish an in-process client-server connection

3. (minor) Enhances error handling and timeout support:
   - Adds configurable read timeouts to sessions via `read_timeout_seconds` parameter
   - Improves exception handling in the server with a new `raise_exceptions` flag to control whether errors are returned to clients or raised directly
   - Ensures proper cleanup of request context tokens in error cases

4. (minor) Makes server improvements:
   - Adds built-in ping handler support
This commit is contained in:
Nick Merrill
2024-11-05 18:42:41 -05:00
committed by GitHub
parent 1a60e1b7c7
commit 60e9c7a0d7
9 changed files with 210 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ from mcp_python.types import (
ClientNotification,
ClientRequest,
CompleteRequest,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
@@ -27,6 +28,7 @@ from mcp_python.types import (
ListToolsRequest,
ListToolsResult,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptReference,
@@ -52,9 +54,11 @@ request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")
logger.debug(f"Initializing server '{name}'")
def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
@@ -63,9 +67,13 @@ class Server:
try:
from importlib.metadata import version
return version(package)
v = version(package)
if v is not None:
return v
except Exception:
return "unknown"
pass
return "unknown"
return types.InitializationOptions(
server_name=self.name,
@@ -330,6 +338,11 @@ class Server:
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
@@ -349,6 +362,7 @@ class Server:
f"Dispatching request of type {type(req).__name__}"
)
token = None
try:
# Set our global state that can be retrieved via
# app.get_request_context()
@@ -360,12 +374,16 @@ class Server:
)
)
response = await handler(req)
# Reset the global state after we are done
request_ctx.reset(token)
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
code=0, message=str(err), data=None
)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)
await message.respond(response)
else:
@@ -399,3 +417,7 @@ class Server:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)
async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())