Close unclosed resources in the whole project (#267)

* Close resources

* Close all resources

* Update pyproject.toml

* Close all resources

* Close all resources

* try now...

* try to ignore this

* try again

* try adding one more..

* try now

* try now

* revert ci changes
This commit is contained in:
Marcelo Trylesinski
2025-03-13 11:59:45 +01:00
committed by GitHub
parent 1691b905e2
commit 94d326dbf1
8 changed files with 64 additions and 11 deletions

View File

@@ -67,6 +67,9 @@ packages = ["src/mcp"]
 include = ["src/mcp", "tests"]
 venvPath = "."
 venv = ".venv"
+strict = [
+    "src/mcp/server/fastmcp/tools/base.py",
+]
 [tool.ruff.lint]
 select = ["E", "F", "I"]
@@ -85,3 +88,13 @@ members = ["examples/servers/*"]
 [tool.uv.sources]
 mcp = { workspace = true }
+[tool.pytest.ini_options]
+xfail_strict = true
+filterwarnings = [
+    "error",
+    # This should be fixed on Uvicorn's side.
+    "ignore::DeprecationWarning:websockets",
+    "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
+    "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
+]

View File

@@ -43,7 +43,9 @@ async def _default_list_roots_callback(
 )
-ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
+ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
+    types.ClientResult | types.ErrorData
+)
 class ClientSession(
@@ -219,7 +221,7 @@ class ClientSession(
     )
     async def call_tool(
-        self, name: str, arguments: dict | None = None
+        self, name: str, arguments: dict[str, Any] | None = None
     ) -> types.CallToolResult:
         """Send a tools/call request."""
         return await self.send_request(
@@ -258,7 +260,9 @@ class ClientSession(
     )
     async def complete(
-        self, ref: types.ResourceReference | types.PromptReference, argument: dict
+        self,
+        ref: types.ResourceReference | types.PromptReference,
+        argument: dict[str, str],
     ) -> types.CompleteResult:
         """Send a completion/complete request."""
         return await self.send_request(

View File

@@ -18,10 +18,10 @@ if TYPE_CHECKING:
 class Tool(BaseModel):
     """Internal tool registration info."""
-    fn: Callable = Field(exclude=True)
+    fn: Callable[..., Any] = Field(exclude=True)
     name: str = Field(description="Name of the tool")
     description: str = Field(description="Description of what the tool does")
-    parameters: dict = Field(description="JSON schema for tool parameters")
+    parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
     fn_metadata: FuncMetadata = Field(
         description="Metadata about the function including a pydantic model for tool"
         " arguments"
@@ -34,7 +34,7 @@ class Tool(BaseModel):
     @classmethod
     def from_function(
         cls,
-        fn: Callable,
+        fn: Callable[..., Any],
         name: str | None = None,
         description: str | None = None,
         context_kwarg: str | None = None,

View File

@@ -102,7 +102,9 @@ class FuncMetadata(BaseModel):
 )
-def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
+def func_metadata(
+    func: Callable[..., Any], skip_names: Sequence[str] = ()
+) -> FuncMetadata:
     """Given a function, return metadata including a pydantic model representing its
     signature.

View File

@@ -1,4 +1,5 @@
 import logging
+from contextlib import AsyncExitStack
 from datetime import timedelta
 from typing import Any, Callable, Generic, TypeVar
@@ -180,6 +181,7 @@ class BaseSession(
         self._read_timeout_seconds = read_timeout_seconds
         self._in_flight = {}
+        self._exit_stack = AsyncExitStack()
         self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
             anyio.create_memory_object_stream[
                 RequestResponder[ReceiveRequestT, SendResultT]
@@ -187,6 +189,12 @@ class BaseSession(
                 | Exception
             ]()
         )
+        self._exit_stack.push_async_callback(
+            lambda: self._incoming_message_stream_reader.aclose()
+        )
+        self._exit_stack.push_async_callback(
+            lambda: self._incoming_message_stream_writer.aclose()
+        )
     async def __aenter__(self) -> Self:
         self._task_group = anyio.create_task_group()
@@ -195,6 +203,7 @@ class BaseSession(
         return self
     async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self._exit_stack.aclose()
         # Using BaseSession as a context manager should not block on exit (this
         # would be very surprising behavior), so make sure to cancel the tasks
         # in the task group.
@@ -222,6 +231,9 @@ class BaseSession(
         ](1)
         self._response_streams[request_id] = response_stream
+        self._exit_stack.push_async_callback(lambda: response_stream.aclose())
+        self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
         jsonrpc_request = JSONRPCRequest(
             jsonrpc="2.0",
             id=request_id,

View File

@@ -83,6 +83,10 @@ async def test_client_session_initialize():
     async with (
         ClientSession(server_to_client_receive, client_to_server_send) as session,
         anyio.create_task_group() as tg,
+        client_to_server_send,
+        client_to_server_receive,
+        server_to_client_send,
+        server_to_client_receive,
     ):
         tg.start_soon(mock_server)
         tg.start_soon(listen_session)

View File

@@ -43,7 +43,13 @@ async def test_request_id_match() -> None:
     )
     # Start server task
-    async with anyio.create_task_group() as tg:
+    async with (
+        anyio.create_task_group() as tg,
+        client_writer,
+        client_reader,
+        server_writer,
+        server_reader,
+    ):
         tg.start_soon(run_server)
     # Send initialize request

View File

@@ -25,7 +25,7 @@ async def test_lowlevel_server_lifespan():
     """Test that lifespan works in low-level server."""
     @asynccontextmanager
-    async def test_lifespan(server: Server) -> AsyncIterator[dict]:
+    async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]:
         """Test lifespan context that tracks startup/shutdown."""
         context = {"started": False, "shutdown": False}
         try:
@@ -50,7 +50,13 @@ async def test_lowlevel_server_lifespan():
         return [{"type": "text", "text": "true"}]
     # Run server in background task
-    async with anyio.create_task_group() as tg:
+    async with (
+        anyio.create_task_group() as tg,
+        send_stream1,
+        receive_stream1,
+        send_stream2,
+        receive_stream2,
+    ):
         async def run_server():
             await server.run(
@@ -147,7 +153,13 @@ async def test_fastmcp_server_lifespan():
         return True
     # Run server in background task
-    async with anyio.create_task_group() as tg:
+    async with (
+        anyio.create_task_group() as tg,
+        send_stream1,
+        receive_stream1,
+        send_stream2,
+        receive_stream2,
+    ):
         async def run_server():
             await server._mcp_server.run(