Merge pull request #8 from modelcontextprotocol/justin/omit-nulls

Exclude `None`s when serializing models
This commit is contained in:
Justin Spahr-Summers
2024-10-03 11:04:42 +01:00
committed by GitHub
9 changed files with 15 additions and 15 deletions

View File

@@ -104,7 +104,7 @@ async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: f
logger.debug(f"Sending client message: {message}") logger.debug(f"Sending client message: {message}")
response = await client.post( response = await client.post(
endpoint_url, endpoint_url,
json=message.model_dump(by_alias=True, mode="json"), json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
response.raise_for_status() response.raise_for_status()
logger.debug( logger.debug(

View File

@@ -70,7 +70,7 @@ async def stdio_client(server: StdioServerParameters):
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True) json = message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send((json + "\n").encode()) await process.stdin.send((json + "\n").encode())
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()

View File

@@ -74,7 +74,7 @@ class SseServerTransport:
await sse_stream_writer.send( await sse_stream_writer.send(
{ {
"event": "message", "event": "message",
"data": message.model_dump_json(by_alias=True), "data": message.model_dump_json(by_alias=True, exclude_none=True),
} }
) )

View File

@@ -48,7 +48,7 @@ async def stdio_server(
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True) json = message.model_dump_json(by_alias=True, exclude_none=True)
await stdout.write(json + "\n") await stdout.write(json + "\n")
await stdout.flush() await stdout.flush()
except anyio.ClosedResourceError: except anyio.ClosedResourceError:

View File

@@ -47,7 +47,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
try: try:
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
obj = message.model_dump(by_alias=True, mode="json") obj = message.model_dump(by_alias=True, mode="json", exclude_none=True)
await websocket.send_json(obj) await websocket.send_json(obj)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await websocket.close() await websocket.close()

View File

@@ -132,7 +132,7 @@ class BaseSession(
self._response_streams[request_id] = response_stream self._response_streams[request_id] = response_stream
jsonrpc_request = JSONRPCRequest( jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json") jsonrpc="2.0", id=request_id, **request.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
# TODO: Support progress callbacks # TODO: Support progress callbacks
@@ -150,7 +150,7 @@ class BaseSession(
Emits a notification, which is a one-way message that does not expect a response. Emits a notification, which is a one-way message that does not expect a response.
""" """
jsonrpc_notification = JSONRPCNotification( jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json") jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
@@ -165,7 +165,7 @@ class BaseSession(
jsonrpc_response = JSONRPCResponse( jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=request_id, id=request_id,
result=response.model_dump(by_alias=True, mode="json"), result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
@@ -180,7 +180,7 @@ class BaseSession(
await self._incoming_message_stream_writer.send(message) await self._incoming_message_stream_writer.send(message)
elif isinstance(message.root, JSONRPCRequest): elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.root.model_dump(by_alias=True, mode="json") message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
responder = RequestResponder( responder = RequestResponder(
request_id=message.root.id, request_id=message.root.id,
@@ -196,7 +196,7 @@ class BaseSession(
await self._incoming_message_stream_writer.send(responder) await self._incoming_message_stream_writer.send(responder)
elif isinstance(message.root, JSONRPCNotification): elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.root.model_dump(by_alias=True, mode="json") message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
await self._received_notification(notification) await self._received_notification(notification)

View File

@@ -35,7 +35,7 @@ async def test_client_session_initialize():
jsonrpc_request = await client_to_server_receive.receive() jsonrpc_request = await client_to_server_receive.receive()
assert isinstance(jsonrpc_request.root, JSONRPCRequest) assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate( request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json") jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
assert isinstance(request.root, InitializeRequest) assert isinstance(request.root, InitializeRequest)
@@ -59,14 +59,14 @@ async def test_client_session_initialize():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump(by_alias=True, mode="json"), result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
) )
) )
) )
jsonrpc_notification = await client_to_server_receive.receive() jsonrpc_notification = await client_to_server_receive.receive()
assert isinstance(jsonrpc_notification.root, JSONRPCNotification) assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate( initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.model_dump(by_alias=True, mode="json") jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
) )
async def listen_session(): async def listen_session():

View File

@@ -18,7 +18,7 @@ async def test_stdio_server():
] ]
for message in messages: for message in messages:
stdin.write(message.model_dump_json() + "\n") stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
stdin.seek(0) stdin.seek(0)
async with stdio_server( async with stdio_server(

View File

@@ -15,7 +15,7 @@ def test_jsonrpc_request():
request = JSONRPCMessage.model_validate(json_data) request = JSONRPCMessage.model_validate(json_data)
assert isinstance(request.root, JSONRPCRequest) assert isinstance(request.root, JSONRPCRequest)
ClientRequest.model_validate(request.model_dump(by_alias=True)) ClientRequest.model_validate(request.model_dump(by_alias=True, exclude_none=True))
assert request.root.jsonrpc == "2.0" assert request.root.jsonrpc == "2.0"
assert request.root.id == 1 assert request.root.id == 1