add type hints

This commit is contained in:
Nick Merrill
2025-01-14 11:54:09 -05:00
parent 3fa26a5a97
commit aa7869a62f

View File

@@ -2,11 +2,12 @@ import multiprocessing
import socket import socket
import time import time
import anyio import anyio
from starlette.requests import Request
import uvicorn import uvicorn
import pytest import pytest
from pydantic import AnyUrl from pydantic import AnyUrl
import httpx import httpx
from typing import AsyncGenerator from typing import AsyncGenerator, Generator
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Mount, Route
@@ -56,7 +57,7 @@ class TestServer(Server):
) )
@self.list_tools() @self.list_tools()
async def handle_list_tools(): async def handle_list_tools() -> list[Tool]:
return [ return [
Tool( Tool(
name="test_tool", name="test_tool",
@@ -66,7 +67,7 @@ class TestServer(Server):
] ]
@self.call_tool() @self.call_tool()
async def handle_call_tool(name: str, args: dict): async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Called {name}")] return [TextContent(type="text", text=f"Called {name}")]
@@ -76,7 +77,7 @@ def make_server_app() -> Starlette:
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
server = TestServer() server = TestServer()
async def handle_sse(request): async def handle_sse(request: Request) -> None:
async with sse.connect_sse( async with sse.connect_sse(
request.scope, request.receive, request._send request.scope, request.receive, request._send
) as streams: ) as streams:
@@ -94,14 +95,7 @@ def make_server_app() -> Starlette:
return app return app
@pytest.fixture(autouse=True) def run_server(server_port: int) -> None:
def space_around_test():
time.sleep(0.1)
yield
time.sleep(0.1)
def run_server(server_port: int):
app = make_server_app() app = make_server_app()
server = uvicorn.Server( server = uvicorn.Server(
config=uvicorn.Config( config=uvicorn.Config(
@@ -118,7 +112,7 @@ def run_server(server_port: int):
@pytest.fixture() @pytest.fixture()
def server(server_port: int): def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process( proc = multiprocessing.Process(
target=run_server, kwargs={"server_port": server_port}, daemon=True target=run_server, kwargs={"server_port": server_port}, daemon=True
) )
@@ -161,11 +155,11 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N
# Tests # Tests
@pytest.mark.anyio @pytest.mark.anyio
async def test_raw_sse_connection(http_client: httpx.AsyncClient): async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
"""Test the SSE connection establishment simply with an HTTP client.""" """Test the SSE connection establishment simply with an HTTP client."""
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
async def connection_test(): async def connection_test() -> None:
async with http_client.stream("GET", "/sse") as response: async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200 assert response.status_code == 200
assert ( assert (
@@ -189,7 +183,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_basic_connection(server, server_url): async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
async with sse_client(server_url + "/sse") as streams: async with sse_client(server_url + "/sse") as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
# Test initialization # Test initialization
@@ -215,7 +209,7 @@ async def initialized_sse_client_session(
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_happy_request_and_response( async def test_sse_client_happy_request_and_response(
initialized_sse_client_session: ClientSession, initialized_sse_client_session: ClientSession,
): ) -> None:
session = initialized_sse_client_session session = initialized_sse_client_session
response = await session.read_resource(uri=AnyUrl("foobar://should-work")) response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
assert len(response.contents) == 1 assert len(response.contents) == 1
@@ -226,7 +220,7 @@ async def test_sse_client_happy_request_and_response(
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_exception_handling( async def test_sse_client_exception_handling(
initialized_sse_client_session: ClientSession, initialized_sse_client_session: ClientSession,
): ) -> None:
session = initialized_sse_client_session session = initialized_sse_client_session
with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
await session.read_resource(uri=AnyUrl("xxx://will-not-work")) await session.read_resource(uri=AnyUrl("xxx://will-not-work"))