refactor: Make types.py strictly typechecked. (#336)

This commit is contained in:
David Soria Parra
2025-03-26 14:21:35 +00:00
committed by GitHub
parent df2d3a57c2
commit 9a2bb6a7e7
4 changed files with 66 additions and 34 deletions

View File

@@ -87,7 +87,6 @@ include = ["src/mcp", "tests"]
venvPath = "."
venv = ".venv"
strict = ["src/mcp/**/*.py"]
exclude = ["src/mcp/types.py"]
[tool.ruff.lint]
select = ["E", "F", "I", "UP"]

View File

@@ -64,8 +64,10 @@ class NotificationParams(BaseModel):
"""
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar(
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
)
MethodT = TypeVar("MethodT", bound=str)
@@ -113,15 +115,16 @@ class PaginatedResult(Result):
"""
class JSONRPCRequest(Request):
class JSONRPCRequest(Request[dict[str, Any] | None, str]):
"""A request that expects a response."""
jsonrpc: Literal["2.0"]
id: RequestId
method: str
params: dict[str, Any] | None = None
class JSONRPCNotification(Notification):
class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
"""A notification which does not expect a response."""
jsonrpc: Literal["2.0"]
@@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class InitializeRequest(Request):
class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]):
"""
This request is sent from the client to the server when it first connects, asking it
to begin initialization.
@@ -298,7 +301,9 @@ class InitializeResult(Result):
"""Instructions describing how to use the server and its features."""
class InitializedNotification(Notification):
class InitializedNotification(
Notification[NotificationParams | None, Literal["notifications/initialized"]]
):
"""
This notification is sent from the client to the server after initialization has
finished.
@@ -308,7 +313,7 @@ class InitializedNotification(Notification):
params: NotificationParams | None = None
class PingRequest(Request):
class PingRequest(Request[RequestParams | None, Literal["ping"]]):
"""
A ping, issued by either the server or the client, to check that the other party is
still alive.
@@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")
class ProgressNotification(Notification):
class ProgressNotification(
Notification[ProgressNotificationParams, Literal["notifications/progress"]]
):
"""
An out-of-band notification used to inform the receiver of a progress update for a
long-running request.
@@ -346,7 +353,9 @@ class ProgressNotification(Notification):
params: ProgressNotificationParams
class ListResourcesRequest(PaginatedRequest):
class ListResourcesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
):
"""Sent from the client to request a list of resources the server has."""
method: Literal["resources/list"]
@@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]
class ListResourceTemplatesRequest(PaginatedRequest):
class ListResourceTemplatesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"]
@@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class ReadResourceRequest(Request):
class ReadResourceRequest(
Request[ReadResourceRequestParams, Literal["resources/read"]]
):
"""Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"]
@@ -472,7 +485,11 @@ class ReadResourceResult(Result):
contents: list[TextResourceContents | BlobResourceContents]
class ResourceListChangedNotification(Notification):
class ResourceListChangedNotification(
Notification[
NotificationParams | None, Literal["notifications/resources/list_changed"]
]
):
"""
An optional notification from the server to the client, informing it that the list
of resources it can read from has changed.
@@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class SubscribeRequest(Request):
class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]):
"""
Sent from the client to request resources/updated notifications from the server
whenever a particular resource changes.
@@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class UnsubscribeRequest(Request):
class UnsubscribeRequest(
Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
):
"""
Sent from the client to request cancellation of resources/updated notifications from
the server.
@@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")
class ResourceUpdatedNotification(Notification):
class ResourceUpdatedNotification(
Notification[
ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
]
):
"""
A notification from the server to the client, informing it that a resource has
changed and may need to be read again.
@@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification):
params: ResourceUpdatedNotificationParams
class ListPromptsRequest(PaginatedRequest):
class ListPromptsRequest(
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
):
"""Sent from the client to request a list of prompts and prompt templates."""
method: Literal["prompts/list"]
@@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class GetPromptRequest(Request):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""
method: Literal["prompts/get"]
@@ -659,7 +684,11 @@ class GetPromptResult(Result):
messages: list[PromptMessage]
class PromptListChangedNotification(Notification):
class PromptListChangedNotification(
Notification[
NotificationParams | None, Literal["notifications/prompts/list_changed"]
]
):
"""
An optional notification from the server to the client, informing it that the list
of prompts it offers has changed.
@@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification):
params: NotificationParams | None = None
class ListToolsRequest(PaginatedRequest):
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""
method: Literal["tools/list"]
@@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class CallToolRequest(Request):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""
method: Literal["tools/call"]
@@ -716,7 +745,9 @@ class CallToolResult(Result):
isError: bool = False
class ToolListChangedNotification(Notification):
class ToolListChangedNotification(
Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
):
"""
An optional notification from the server to the client, informing it that the list
of tools it offers has changed.
@@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class SetLevelRequest(Request):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""
method: Literal["logging/setLevel"]
@@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")
class LoggingMessageNotification(Notification):
class LoggingMessageNotification(
Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
):
"""Notification of a log message passed from server to client."""
method: Literal["notifications/message"]
@@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class CreateMessageRequest(Request):
class CreateMessageRequest(
Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
):
"""A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"]
@@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")
class CompleteRequest(Request):
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""
method: Literal["completion/complete"]
@@ -944,7 +979,7 @@ class CompleteResult(Result):
completion: Completion
class ListRootsRequest(Request):
class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
"""
Sent from the server to request a list of root URIs from the client. Roots allow
servers to ask for specific directories or files to operate on. A common example
@@ -987,7 +1022,9 @@ class ListRootsResult(Result):
roots: list[Root]
class RootsListChangedNotification(Notification):
class RootsListChangedNotification(
Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]]
):
"""
A notification from the client to the server, informing it that the list of
roots has changed.

View File

@@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield

View File

@@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
yield