Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -38,9 +38,7 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar(
"ReceiveNotificationT", ClientNotification, ServerNotification
)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
RequestId = str | int
@@ -48,9 +46,7 @@ RequestId = str | int
class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""
async def __call__(
self, progress: float, total: float | None, message: str | None
) -> None: ...
async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@@ -177,9 +173,7 @@ class BaseSession(
messages when entered.
"""
_response_streams: dict[
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
@@ -242,9 +236,7 @@ class BaseSession(
request_id = self._request_id
self._request_id = request_id + 1
response_stream, response_stream_reader = anyio.create_memory_object_stream[
JSONRPCResponse | JSONRPCError
](1)
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
self._response_streams[request_id] = response_stream
# Set up progress token if progress callback is provided
@@ -266,11 +258,7 @@ class BaseSession(
**request_data,
)
await self._write_stream.send(
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
)
)
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
# request read timeout takes precedence over session read timeout
timeout = None
@@ -322,15 +310,11 @@ class BaseSession(
)
session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id)
if related_request_id
else None,
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
await self._write_stream.send(session_message)
async def _send_response(
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -339,9 +323,7 @@ class BaseSession(
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(
by_alias=True, mode="json", exclude_none=True
),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(session_message)
@@ -357,19 +339,14 @@ class BaseSession(
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(
r.request_id, None),
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
@@ -381,9 +358,7 @@ class BaseSession(
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(
f"Message that failed validation: {message.message.root}"
)
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
@@ -393,16 +368,13 @@ class BaseSession(
data="",
),
)
session_message = SessionMessage(
message=JSONRPCMessage(error_response))
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
@@ -427,8 +399,7 @@ class BaseSession(
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.message.root}"
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
@@ -436,10 +407,7 @@ class BaseSession(
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
)
RuntimeError("Received response with an unknown " f"request ID: {message}")
)
# after the read stream is closed, we need to send errors
@@ -450,9 +418,7 @@ class BaseSession(
await stream.aclose()
self._response_streams.clear()
async def _received_request(
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None:
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
@@ -481,9 +447,7 @@ class BaseSession(
async def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass