Format with ruff

This commit is contained in:
David Soria Parra
2024-10-11 11:54:16 +01:00
parent 9475815241
commit fd68df6687
15 changed files with 268 additions and 101 deletions

View File

@@ -69,9 +69,11 @@ class BaseSession(
],
):
"""
Implements an MCP "session" on top of read/write streams, including features like request/response linking, notifications, and progress.
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is an async context manager that automatically starts processing messages when entered.
This class is an async context manager that automatically starts processing
messages when entered.
"""
_response_streams: dict[
@@ -108,7 +110,9 @@ class BaseSession(
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Using BaseSession as a context manager should not block on exit (this would be very surprising behavior), so make sure to cancel the tasks in the task group.
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
@@ -118,9 +122,11 @@ class BaseSession(
result_type: type[ReceiveResultT],
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the response contains an error.
Sends a request and wait for a response. Raises an McpError if the
response contains an error.
Do not use this method to emit notifications! Use send_notification() instead.
Do not use this method to emit notifications! Use send_notification()
instead.
"""
request_id = self._request_id
@@ -132,7 +138,9 @@ class BaseSession(
self._response_streams[request_id] = response_stream
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json", exclude_none=True)
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
# TODO: Support progress callbacks
@@ -147,10 +155,12 @@ class BaseSession(
async def send_notification(self, notification: SendNotificationT) -> None:
"""
Emits a notification, which is a one-way message that does not expect a response.
Emits a notification, which is a one-way message that does not expect
a response.
"""
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True)
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
@@ -165,7 +175,9 @@ class BaseSession(
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
result=response.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
@@ -180,7 +192,9 @@ class BaseSession(
await self._incoming_message_stream_writer.send(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.root.id,
@@ -196,7 +210,9 @@ class BaseSession(
await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
await self._received_notification(notification)
@@ -208,7 +224,8 @@ class BaseSession(
else:
await self._incoming_message_stream_writer.send(
RuntimeError(
f"Received response with an unknown request ID: {message}"
"Received response with an unknown "
f"request ID: {message}"
)
)
@@ -216,21 +233,25 @@ class BaseSession(
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None:
"""
Can be overridden by subclasses to handle a request without needing to listen on the message stream.
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
If the request is responded to within this method, it will not be forwarded on to the message stream.
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing to listen on the message stream.
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""
Sends a progress notification for a request that is currently being processed.
Sends a progress notification for a request that is currently being
processed.
"""
@property