mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -38,9 +38,7 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar(
|
||||
"ReceiveNotificationT", ClientNotification, ServerNotification
|
||||
)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
@@ -48,9 +46,7 @@ RequestId = str | int
|
||||
class ProgressFnT(Protocol):
|
||||
"""Protocol for progress notification callbacks."""
|
||||
|
||||
async def __call__(
|
||||
self, progress: float, total: float | None, message: str | None
|
||||
) -> None: ...
|
||||
async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
@@ -177,9 +173,7 @@ class BaseSession(
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[
|
||||
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
|
||||
]
|
||||
_response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_progress_callbacks: dict[RequestId, ProgressFnT]
|
||||
@@ -242,9 +236,7 @@ class BaseSession(
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_stream, response_stream_reader = anyio.create_memory_object_stream[
|
||||
JSONRPCResponse | JSONRPCError
|
||||
](1)
|
||||
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
# Set up progress token if progress callback is provided
|
||||
@@ -266,11 +258,7 @@ class BaseSession(
|
||||
**request_data,
|
||||
)
|
||||
|
||||
await self._write_stream.send(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
|
||||
)
|
||||
)
|
||||
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
@@ -322,15 +310,11 @@ class BaseSession(
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id)
|
||||
if related_request_id
|
||||
else None,
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _send_response(
|
||||
self, request_id: RequestId, response: SendResultT | ErrorData
|
||||
) -> None:
|
||||
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
@@ -339,9 +323,7 @@ class BaseSession(
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
await self._write_stream.send(session_message)
|
||||
@@ -357,19 +339,14 @@ class BaseSession(
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
try:
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(
|
||||
r.request_id, None),
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
message_metadata=message.metadata,
|
||||
)
|
||||
self._in_flight[responder.request_id] = responder
|
||||
@@ -381,9 +358,7 @@ class BaseSession(
|
||||
# For request validation errors, send a proper JSON-RPC error
|
||||
# response instead of crashing the server
|
||||
logging.warning(f"Failed to validate request: {e}")
|
||||
logging.debug(
|
||||
f"Message that failed validation: {message.message.root}"
|
||||
)
|
||||
logging.debug(f"Message that failed validation: {message.message.root}")
|
||||
error_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=message.message.root.id,
|
||||
@@ -393,16 +368,13 @@ class BaseSession(
|
||||
data="",
|
||||
),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(error_response))
|
||||
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(
|
||||
by_alias=True, mode="json", exclude_none=True
|
||||
)
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
@@ -427,8 +399,7 @@ class BaseSession(
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. "
|
||||
f"Message was: {message.message.root}"
|
||||
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
stream = self._response_streams.pop(message.message.root.id, None)
|
||||
@@ -436,10 +407,7 @@ class BaseSession(
|
||||
await stream.send(message.message.root)
|
||||
else:
|
||||
await self._handle_incoming(
|
||||
RuntimeError(
|
||||
"Received response with an unknown "
|
||||
f"request ID: {message}"
|
||||
)
|
||||
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
||||
)
|
||||
|
||||
# after the read stream is closed, we need to send errors
|
||||
@@ -450,9 +418,7 @@ class BaseSession(
|
||||
await stream.aclose()
|
||||
self._response_streams.clear()
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
) -> None:
|
||||
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
@@ -481,9 +447,7 @@ class BaseSession(
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT]
|
||||
| ReceiveNotificationT
|
||||
| Exception,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user