mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Initial import
This commit is contained in:
347
mcp_python/server/__init__.py
Normal file
347
mcp_python/server/__init__.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.server import types
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.server.stdio import stdio_server as stdio_server
|
||||
from mcp_python.shared.context import RequestContext
|
||||
from mcp_python.shared.session import RequestResponder
|
||||
from mcp_python.types import (
|
||||
METHOD_NOT_FOUND,
|
||||
CallToolRequest,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CompleteRequest,
|
||||
ErrorData,
|
||||
JSONRPCMessage,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
LoggingLevel,
|
||||
ProgressNotification,
|
||||
Prompt,
|
||||
PromptReference,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourceReference,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
SubscribeRequest,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
||||
"request_ctx"
|
||||
)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
logger.info(f"Initializing server '{name}'")
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
from mcp_python.types import ListPromptsRequest, ListPromptsResult
|
||||
|
||||
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
|
||||
logger.debug(f"Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return ServerResult(ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
from mcp_python.types import (
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
Role as Role,
|
||||
SamplingMessage,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
messages = []
|
||||
for message in prompt_get.messages:
|
||||
match message.content:
|
||||
case str() as text_content:
|
||||
content = TextContent(type="text", text=text_content)
|
||||
case types.ImageContent() as img_content:
|
||||
content = ImageContent(
|
||||
type="image",
|
||||
data=img_content.data,
|
||||
mimeType=img_content.mime_type,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {type(message.content)}"
|
||||
)
|
||||
|
||||
sampling_message = SamplingMessage(
|
||||
role=message.role, content=content
|
||||
)
|
||||
messages.append(sampling_message)
|
||||
|
||||
return ServerResult(
|
||||
GetPromptResult(description=prompt_get.desc, messages=messages)
|
||||
)
|
||||
|
||||
self.request_handlers[GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
|
||||
logger.debug(f"Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return ServerResult(
|
||||
ListResourcesResult(resources=resources, resourceTemplates=None)
|
||||
)
|
||||
|
||||
self.request_handlers[ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
from mcp_python.types import (
|
||||
BlobResourceContents,
|
||||
TextResourceContents,
|
||||
)
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug(f"Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
match result:
|
||||
case str(s):
|
||||
content = TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=s,
|
||||
mimeType="text/plain",
|
||||
)
|
||||
case bytes(b):
|
||||
import base64
|
||||
|
||||
content = BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.urlsafe_b64encode(b).decode(),
|
||||
mimeType="application/octet-stream",
|
||||
)
|
||||
|
||||
return ServerResult(
|
||||
ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
from mcp_python.types import EmptyResult
|
||||
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug(f"Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return ServerResult(EmptyResult())
|
||||
|
||||
self.request_handlers[UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
from mcp_python.types import CallToolResult
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
logger.debug(f"Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: CallToolRequest):
|
||||
result = await func(req.params.name, **(req.params.arguments or {}))
|
||||
return ServerResult(CallToolResult(toolResult=result))
|
||||
|
||||
self.request_handlers[CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug(f"Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
from mcp_python.types import CompleteResult, Completion, CompletionArgument
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[PromptReference | ResourceReference, CompletionArgument],
|
||||
Awaitable[Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug(f"Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return ServerResult(
|
||||
CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(read_stream, write_stream) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=ClientRequest(root=req)):
|
||||
logger.info(
|
||||
f"Processing request of type {type(req).__name__}"
|
||||
)
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
# Reset the global state after we are done
|
||||
request_ctx.reset(token)
|
||||
except Exception as err:
|
||||
response = ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
ErrorData(
|
||||
code=METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
case ClientNotification(root=notify):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type {type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: {err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
35
mcp_python/server/__main__.py
Normal file
35
mcp_python/server/__main__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp_python.server.session import ServerSession
|
||||
from mcp_python.server.stdio import stdio_server
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("server")
|
||||
|
||||
|
||||
async def receive_loop(session: ServerSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from client: %s", message)
|
||||
|
||||
|
||||
async def main():
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with ServerSession(read_stream, write_stream) as session, write_stream:
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main, backend="trio")
|
||||
203
mcp_python/server/session.py
Normal file
203
mcp_python/server/session.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_python.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
|
||||
from mcp_python.types import (
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
CreateMessageResult,
|
||||
EmptyResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
LoggingLevel,
|
||||
SamplingMessage,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
NotInitialized = 1
|
||||
Initializing = 2
|
||||
Initialized = 3
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
ServerResult,
|
||||
ClientRequest,
|
||||
ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
|
||||
self._initialization_state = InitializationState.NotInitialized
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ClientRequest, ServerResult]
|
||||
):
|
||||
match responder.request.root:
|
||||
case InitializeRequest():
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
await responder.respond(
|
||||
ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(
|
||||
logging=None,
|
||||
resources=None,
|
||||
tools=None,
|
||||
experimental=None,
|
||||
prompts={},
|
||||
),
|
||||
serverInfo=Implementation(
|
||||
name="mcp_python", version="0.1.0"
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
|
||||
async def _received_notification(self, notification: ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
|
||||
async def send_log_message(
|
||||
self, level: LoggingLevel, data: Any, logger: str | None = None
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
from mcp_python.types import (
|
||||
LoggingMessageNotification,
|
||||
LoggingMessageNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
LoggingMessageNotification(
|
||||
method="notifications/message",
|
||||
params=LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None:
|
||||
"""Send a resource updated notification."""
|
||||
from mcp_python.types import (
|
||||
ResourceUpdatedNotification,
|
||||
ResourceUpdatedNotificationParams,
|
||||
)
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ResourceUpdatedNotification(
|
||||
method="notifications/resources/updated",
|
||||
params=ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def request_create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
from mcp_python.types import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestParams,
|
||||
)
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
CreateMessageRequest(
|
||||
method="sampling/createMessage",
|
||||
params=CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
maxTokens=max_tokens,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
),
|
||||
)
|
||||
),
|
||||
CreateMessageResult,
|
||||
)
|
||||
|
||||
async def send_ping(self) -> EmptyResult:
|
||||
"""Send a ping request."""
|
||||
from mcp_python.types import PingRequest
|
||||
|
||||
return await self.send_request(
|
||||
ServerRequest(
|
||||
PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
from mcp_python.types import ProgressNotification, ProgressNotificationParams
|
||||
|
||||
await self.send_notification(
|
||||
ServerNotification(
|
||||
ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
133
mcp_python/server/sse.py
Normal file
133
mcp_python/server/sse.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SseServerTransport:
|
||||
"""
|
||||
SSE server transport for MCP. This class provides _two_ ASGI applications, suitable to be used with a framework like Starlette and a server like Hypercorn:
|
||||
|
||||
1. connect_sse() is an ASGI application which receives incoming GET requests, and sets up a new SSE stream to send server messages to the client.
|
||||
2. handle_post_message() is an ASGI application which receives incoming POST requests, which should contain client messages that link to a previously-established SSE session.
|
||||
"""
|
||||
|
||||
_endpoint: str
|
||||
_read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]]
|
||||
|
||||
def __init__(self, endpoint: str) -> None:
|
||||
"""
|
||||
Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self._endpoint = endpoint
|
||||
self._read_stream_writers = {}
|
||||
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
logger.error("connect_sse received non-HTTP request")
|
||||
raise ValueError("connect_sse can only handle HTTP requests")
|
||||
|
||||
logger.debug("Setting up SSE connection")
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
session_id = uuid4()
|
||||
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
|
||||
self._read_stream_writers[session_id] = read_stream_writer
|
||||
logger.debug(f"Created new session with ID: {session_id}")
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
|
||||
0, dict[str, Any]
|
||||
)
|
||||
|
||||
async def sse_writer():
|
||||
logger.debug("Starting SSE writer")
|
||||
async with sse_stream_writer, write_stream_reader:
|
||||
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
|
||||
logger.debug(f"Sent endpoint event: {session_uri}")
|
||||
|
||||
async for message in write_stream_reader:
|
||||
logger.debug(f"Sending message via SSE: {message}")
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": message.model_dump_json(by_alias=True),
|
||||
}
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader, data_sender_callable=sse_writer
|
||||
)
|
||||
logger.debug("Starting SSE response task")
|
||||
tg.start_soon(response, scope, receive, send)
|
||||
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
logger.debug("Handling POST message")
|
||||
request = Request(scope, receive)
|
||||
|
||||
session_id_param = request.query_params.get("session_id")
|
||||
if session_id_param is None:
|
||||
logger.warning("Received request without session_id")
|
||||
response = Response("session_id is required", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
try:
|
||||
session_id = UUID(hex=session_id_param)
|
||||
logger.debug(f"Parsed session ID: {session_id}")
|
||||
except ValueError:
|
||||
logger.warning(f"Received invalid session ID: {session_id_param}")
|
||||
response = Response("Invalid session ID", status_code=400)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
writer = self._read_stream_writers.get(session_id)
|
||||
if not writer:
|
||||
logger.warning(f"Could not find session for ID: {session_id}")
|
||||
response = Response("Could not find session", status_code=404)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
json = await request.json()
|
||||
logger.debug(f"Received JSON: {json}")
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate(json)
|
||||
logger.debug(f"Validated client message: {message}")
|
||||
except ValidationError as err:
|
||||
logger.error(f"Failed to parse message: {err}")
|
||||
response = Response("Could not parse message", status_code=400)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(err)
|
||||
return
|
||||
|
||||
logger.debug(f"Sending message to writer: {message}")
|
||||
response = Response("Accepted", status_code=202)
|
||||
await response(scope, receive, send)
|
||||
await writer.send(message)
|
||||
60
mcp_python/server/stdio.py
Normal file
60
mcp_python/server/stdio.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_server(
|
||||
stdin: anyio.AsyncFile | None = None, stdout: anyio.AsyncFile | None = None
|
||||
):
|
||||
"""
|
||||
Server transport for stdio: this communicates with an MCP client by reading from the current process' stdin and writing to stdout.
|
||||
"""
|
||||
# Purposely not using context managers for these, as we don't want to close standard process handles.
|
||||
if not stdin:
|
||||
stdin = anyio.wrap_file(sys.stdin)
|
||||
if not stdout:
|
||||
stdout = anyio.wrap_file(sys.stdout)
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def stdin_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for line in stdin:
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdout_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True)
|
||||
await stdout.write(json + "\n")
|
||||
await stdout.flush()
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(stdin_reader)
|
||||
tg.start_soon(stdout_writer)
|
||||
yield read_stream, write_stream
|
||||
27
mcp_python/server/types.py
Normal file
27
mcp_python/server/types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
This module provides simpler types to use with the server for managing prompts.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from mcp_python.types import Role
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageContent:
|
||||
type: Literal["image"]
|
||||
data: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: Role
|
||||
content: str | ImageContent
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptResponse:
|
||||
messages: list[Message]
|
||||
desc: str | None = None
|
||||
58
mcp_python/server/websocket.py
Normal file
58
mcp_python/server/websocket.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
from mcp_python.types import JSONRPCMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
"""
|
||||
WebSocket server transport for MCP. This is an ASGI application, suitable to be used with a framework like Starlette and a server like Hypercorn.
|
||||
"""
|
||||
|
||||
websocket = WebSocket(scope, receive, send)
|
||||
await websocket.accept(subprotocol="mcp")
|
||||
|
||||
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[JSONRPCMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async def ws_reader():
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for message in websocket.iter_json():
|
||||
try:
|
||||
client_message = JSONRPCMessage.model_validate(message)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
await read_stream_writer.send(client_message)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async def ws_writer():
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
obj = message.model_dump(by_alias=True, mode="json")
|
||||
await websocket.send_json(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
yield (read_stream, write_stream)
|
||||
Reference in New Issue
Block a user