Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken:
class TestBearerAuthBackend:
"""Tests for the BearerAuthBackend class."""
async def test_no_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
"""Test authentication with no Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request({"type": "http", "headers": []})
result = await backend.authenticate(request)
assert result is None
async def test_non_bearer_auth_header(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
"""Test authentication with non-Bearer Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
@@ -139,9 +135,7 @@ class TestBearerAuthBackend:
result = await backend.authenticate(request)
assert result is None
async def test_invalid_token(
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
"""Test authentication with invalid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request(
@@ -160,9 +154,7 @@ class TestBearerAuthBackend:
):
"""Test authentication with expired token."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "expired_token", expired_access_token
)
add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
request = Request(
{
"type": "http",
@@ -203,9 +195,7 @@ class TestBearerAuthBackend:
):
"""Test authentication with token that has no expiry."""
backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider(
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
)
add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
request = Request(
{
"type": "http",

View File

@@ -128,16 +128,12 @@ class TestRegistrationErrorHandling:
class TestAuthorizeErrorHandling:
@pytest.mark.anyio
async def test_authorize_error_handling(
self, client, oauth_provider, registered_client, pkce_challenge
):
async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge):
# Mock the authorize method to raise an authorize error
with unittest.mock.patch.object(
oauth_provider,
"authorize",
side_effect=AuthorizeError(
error="access_denied", error_description="The user denied the request"
),
side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"),
):
# Register the client
client_id = registered_client["client_id"]
@@ -169,9 +165,7 @@ class TestAuthorizeErrorHandling:
class TestTokenErrorHandling:
@pytest.mark.anyio
async def test_token_error_handling_auth_code(
self, client, oauth_provider, registered_client, pkce_challenge
):
async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge):
# Register the client and get an auth code
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]
@@ -224,9 +218,7 @@ class TestTokenErrorHandling:
assert data["error_description"] == "The authorization code is invalid"
@pytest.mark.anyio
async def test_token_error_handling_refresh_token(
self, client, oauth_provider, registered_client, pkce_challenge
):
async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge):
# Register the client and get tokens
client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"]

View File

@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
async def register_client(self, client_info: OAuthClientInformationFull):
self.clients[client_info.client_id] = client_info
async def authorize(
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
# toy authorize implementation which just immediately generates an authorization
# code and completes the redirect
code = AuthorizationCode(
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
)
self.auth_codes[code.code] = code
return construct_redirect_uri(
str(params.redirect_uri), code=code.code, state=params.state
)
return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)
async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
refresh_token=refresh_token,
)
async def load_refresh_token(
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshToken | None:
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
old_access_token = self.refresh_tokens.get(refresh_token)
if old_access_token is None:
return None
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):
@pytest.fixture
async def test_client(auth_app):
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
) as client:
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
yield client
@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
def pkce_challenge():
"""Create a PKCE challenge with code_verifier and code_challenge."""
code_verifier = "some_random_verifier_string"
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode()
.rstrip("=")
)
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
@@ -356,17 +344,13 @@ class TestAuthEndpoints:
metadata = response.json()
assert metadata["issuer"] == "https://auth.example.com/"
assert (
metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
)
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
assert metadata["token_endpoint"] == "https://auth.example.com/token"
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
assert metadata["response_types_supported"] == ["code"]
assert metadata["code_challenge_methods_supported"] == ["S256"]
assert metadata["token_endpoint_auth_methods_supported"] == [
"client_secret_post"
]
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
assert metadata["grant_types_supported"] == [
"authorization_code",
"refresh_token",
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
)
error_response = response.json()
assert error_response["error"] == "invalid_request"
assert (
"error_description" in error_response
) # Contains validation error messages
assert "error_description" in error_response # Contains validation error messages
@pytest.mark.anyio
async def test_token_invalid_auth_code(
self, test_client, registered_client, pkce_challenge
):
async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
"""Test token endpoint error - authorization code does not exist."""
# Try to use a non-existent authorization code
response = await test_client.post(
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
assert response.status_code == 400
error_response = response.json()
assert error_response["error"] == "invalid_grant"
assert (
"authorization code does not exist" in error_response["error_description"]
)
assert "authorization code does not exist" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_expired_auth_code(
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
assert response.status_code == 400
error_response = response.json()
assert error_response["error"] == "invalid_grant"
assert (
"authorization code has expired" in error_response["error_description"]
)
assert "authorization code has expired" in error_response["error_description"]
@pytest.mark.anyio
@pytest.mark.parametrize(
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
],
indirect=True,
)
async def test_token_redirect_uri_mismatch(
self, test_client, registered_client, auth_code, pkce_challenge
):
async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
"""Test token endpoint error - redirect URI mismatch."""
# Try to use the code with a different redirect URI
response = await test_client.post(
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
assert "redirect_uri did not match" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_code_verifier_mismatch(
self, test_client, registered_client, auth_code
):
async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
"""Test token endpoint error - PKCE code verifier mismatch."""
# Try to use the code with an incorrect code verifier
response = await test_client.post(
@@ -569,9 +541,7 @@ class TestAuthEndpoints:
# Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
# Mock the time.time() function to return a value 4 hours in the future
with unittest.mock.patch(
"time.time", return_value=current_time + 14400
): # 4 hours = 14400 seconds
with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
# Try to use the refresh token which should now be considered expired
response = await test_client.post(
"/token",
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
assert "refresh token has expired" in error_response["error_description"]
@pytest.mark.anyio
async def test_token_invalid_scope(
self, test_client, registered_client, auth_code, pkce_challenge
):
async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
"""Test token endpoint error - invalid scope in refresh token request."""
# Exchange authorization code for tokens
token_response = await test_client.post(
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
assert "cannot request scope" in error_response["error_description"]
@pytest.mark.anyio
async def test_client_registration(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
):
async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
"""Test client registration."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
# ) is not None
@pytest.mark.anyio
async def test_client_registration_missing_required_fields(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
"""Test client registration with missing required fields."""
# Missing redirect_uris which is a required field
client_metadata = {
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
assert error_data["error_description"] == "redirect_uris: Field required"
@pytest.mark.anyio
async def test_client_registration_invalid_uri(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
"""Test client registration with invalid URIs."""
# Invalid redirect_uri format
client_metadata = {
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert error_data["error_description"] == (
"redirect_uris.0: Input should be a valid URL, "
"relative URL without a base"
"redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
)
@pytest.mark.anyio
async def test_client_registration_empty_redirect_uris(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
"""Test client registration with empty redirect_uris array."""
client_metadata = {
"redirect_uris": [], # Empty array
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert (
error_data["error_description"]
== "redirect_uris: List should have at least 1 item after validation, not 0"
error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
)
@pytest.mark.anyio
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
assert response.status_code == 200
# Verify that the token was revoked
assert (
await mock_oauth_provider.load_access_token(
new_token_response["access_token"]
)
is None
)
assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None
@pytest.mark.anyio
async def test_revoke_invalid_token(self, test_client, registered_client):
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
assert "token_type_hint" in error_response["error_description"]
@pytest.mark.anyio
async def test_client_registration_disallowed_scopes(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
"""Test client registration with scopes that are not allowed."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
assert client_info["scope"] == "read write"
# Retrieve the client from the store to verify default scopes
registered_client = await mock_oauth_provider.get_client(
client_info["client_id"]
)
registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
assert registered_client is not None
# Check that default scopes were applied
assert registered_client.scope == "read write"
@pytest.mark.anyio
async def test_client_registration_invalid_grant_type(
self, test_client: httpx.AsyncClient
):
async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Test Client",
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
error_data = response.json()
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert (
error_data["error_description"]
== "grant_types must be authorization_code and refresh_token"
)
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""
@pytest.mark.anyio
async def test_authorize_missing_client_id(
self, test_client: httpx.AsyncClient, pkce_challenge
):
async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with missing client_id.
According to the OAuth2.0 spec, if client_id is missing, the server should
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
assert "client_id" in response.text.lower()
@pytest.mark.anyio
async def test_authorize_invalid_client_id(
self, test_client: httpx.AsyncClient, pkce_challenge
):
async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with invalid client_id.
According to the OAuth2.0 spec, if client_id is invalid, the server should
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"
@pytest.mark.anyio
async def test_authorize_missing_pkce_challenge(
self, test_client: httpx.AsyncClient, registered_client
):
async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
"""Test authorization endpoint with missing PKCE code_challenge.
Missing PKCE parameters should result in invalid_request error.
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"
@pytest.mark.anyio
async def test_authorize_invalid_scope(
self, test_client: httpx.AsyncClient, registered_client, pkce_challenge
):
async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
"""Test authorization endpoint with invalid scope.
Invalid scope should redirect with invalid_scope error.

View File

@@ -18,9 +18,7 @@ class TestRenderPrompt:
return "Hello, world!"
prompt = Prompt.from_function(fn)
assert await prompt.render() == [
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
@pytest.mark.anyio
async def test_async_fn(self):
@@ -28,9 +26,7 @@ class TestRenderPrompt:
return "Hello, world!"
prompt = Prompt.from_function(fn)
assert await prompt.render() == [
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
@pytest.mark.anyio
async def test_fn_with_args(self):
@@ -39,11 +35,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn)
assert await prompt.render(arguments={"name": "World"}) == [
UserMessage(
content=TextContent(
type="text", text="Hello, World! You're 30 years old."
)
)
UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old."))
]
@pytest.mark.anyio
@@ -61,21 +53,15 @@ class TestRenderPrompt:
return UserMessage(content="Hello, world!")
prompt = Prompt.from_function(fn)
assert await prompt.render() == [
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
@pytest.mark.anyio
async def test_fn_returns_assistant_message(self):
async def fn() -> AssistantMessage:
return AssistantMessage(
content=TextContent(type="text", text="Hello, world!")
)
return AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
prompt = Prompt.from_function(fn)
assert await prompt.render() == [
AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
]
assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))]
@pytest.mark.anyio
async def test_fn_returns_multiple_messages(self):
@@ -156,9 +142,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn)
assert await prompt.render() == [
UserMessage(
content=TextContent(type="text", text="Please analyze this file:")
),
UserMessage(content=TextContent(type="text", text="Please analyze this file:")),
UserMessage(
content=EmbeddedResource(
type="resource",
@@ -169,9 +153,7 @@ class TestRenderPrompt:
),
)
),
AssistantMessage(
content=TextContent(type="text", text="I'll help analyze that file.")
),
AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")),
]
@pytest.mark.anyio

View File

@@ -72,9 +72,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
messages = await manager.render_prompt("fn")
assert messages == [
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
@pytest.mark.anyio
async def test_render_prompt_with_args(self):
@@ -87,9 +85,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
messages = await manager.render_prompt("fn", arguments={"name": "World"})
assert messages == [
UserMessage(content=TextContent(type="text", text="Hello, World!"))
]
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))]
@pytest.mark.anyio
async def test_render_unknown_prompt(self):

View File

@@ -100,9 +100,7 @@ class TestFileResource:
with pytest.raises(ValueError, match="Error reading file"):
await resource.read()
@pytest.mark.skipif(
os.name == "nt", reason="File permissions behave differently on Windows"
)
@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows")
@pytest.mark.anyio
async def test_permission_error(self, temp_file: Path):
"""Test reading a file without permissions."""

View File

@@ -28,9 +28,7 @@ def complex_arguments_fn(
# list[str] | str is an interesting case because if it comes in as JSON like
# "[\"a\", \"b\"]" then it will be naively parsed as a string.
list_str_or_str: list[str] | str,
an_int_annotated_with_field: Annotated[
int, Field(description="An int with a field")
],
an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")],
an_int_annotated_with_field_and_others: Annotated[
int,
str, # Should be ignored, really
@@ -42,9 +40,7 @@ def complex_arguments_fn(
"123",
456,
],
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[
int, Field(1)
],
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)],
unannotated,
my_model_a: SomeInputModelA,
my_model_a_forward_ref: "SomeInputModelA",
@@ -179,9 +175,7 @@ def test_str_vs_list_str():
def test_skip_names():
"""Test that skipped parameters are not included in the model"""
def func_with_many_params(
keep_this: int, skip_this: str, also_keep: float, also_skip: bool
):
def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool):
return keep_this, skip_this, also_keep, also_skip
# Skip some parameters

View File

@@ -130,11 +130,7 @@ def make_everything_fastmcp() -> FastMCP:
# Request sampling from the client
result = await ctx.session.create_message(
messages=[
SamplingMessage(
role="user", content=TextContent(type="text", text=prompt)
)
],
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))],
max_tokens=100,
temperature=0.7,
)
@@ -278,11 +274,7 @@ def make_fastmcp_stateless_http_app():
def run_server(server_port: int) -> None:
"""Run the server."""
_, app = make_fastmcp_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting server on port {server_port}")
server.run()
@@ -290,11 +282,7 @@ def run_server(server_port: int) -> None:
def run_everything_legacy_sse_http_server(server_port: int) -> None:
"""Run the comprehensive server with all features."""
_, app = make_everything_fastmcp_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting comprehensive server on port {server_port}")
server.run()
@@ -302,11 +290,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None:
def run_streamable_http_server(server_port: int) -> None:
"""Run the StreamableHTTP server."""
_, app = make_fastmcp_streamable_http_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting StreamableHTTP server on port {server_port}")
server.run()
@@ -314,11 +298,7 @@ def run_streamable_http_server(server_port: int) -> None:
def run_everything_server(server_port: int) -> None:
"""Run the comprehensive StreamableHTTP server with all features."""
_, app = make_everything_fastmcp_streamable_http_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting comprehensive StreamableHTTP server on port {server_port}")
server.run()
@@ -326,11 +306,7 @@ def run_everything_server(server_port: int) -> None:
def run_stateless_http_server(server_port: int) -> None:
"""Run the stateless StreamableHTTP server."""
_, app = make_fastmcp_stateless_http_app()
server = uvicorn.Server(
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting stateless StreamableHTTP server on port {server_port}")
server.run()
@@ -369,9 +345,7 @@ def server(server_port: int) -> Generator[None, None, None]:
@pytest.fixture()
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
"""Start the StreamableHTTP server in a separate process."""
proc = multiprocessing.Process(
target=run_streamable_http_server, args=(http_server_port,), daemon=True
)
proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True)
print("Starting StreamableHTTP server process")
proc.start()
@@ -388,9 +362,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None]
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"StreamableHTTP server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts")
yield
@@ -427,9 +399,7 @@ def stateless_http_server(
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Stateless server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts")
yield
@@ -459,9 +429,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
@pytest.mark.anyio
async def test_fastmcp_streamable_http(
streamable_http_server: None, http_server_url: str
) -> None:
async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None:
"""Test that FastMCP works with StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(http_server_url + "/mcp") as (
@@ -484,9 +452,7 @@ async def test_fastmcp_streamable_http(
@pytest.mark.anyio
async def test_fastmcp_stateless_streamable_http(
stateless_http_server: None, stateless_http_server_url: str
) -> None:
async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None:
"""Test that FastMCP works with stateless StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
@@ -562,9 +528,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Comprehensive server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts")
yield
@@ -601,10 +565,7 @@ def everything_streamable_http_server(
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Comprehensive StreamableHTTP server failed to start after "
f"{max_attempts} attempts"
)
raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts")
yield
@@ -648,9 +609,7 @@ class NotificationCollector:
await self.handle_tool_list_changed(message.root.params)
async def call_all_mcp_features(
session: ClientSession, collector: NotificationCollector
) -> None:
async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None:
"""
Test all MCP features using the provided session.
@@ -680,9 +639,7 @@ async def call_all_mcp_features(
# Test progress callback functionality
progress_updates = []
async def progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
"""Collect progress updates for testing (async version)."""
progress_updates.append((progress, total, message))
print(f"Progress: {progress}/{total} - {message}")
@@ -726,19 +683,12 @@ async def call_all_mcp_features(
# Verify we received log messages from the sampling tool
assert len(collector.log_messages) > 0
assert any(
"Requesting sampling for prompt" in msg.data for msg in collector.log_messages
)
assert any(
"Received sampling result from model" in msg.data
for msg in collector.log_messages
)
assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages)
assert any("Received sampling result from model" in msg.data for msg in collector.log_messages)
# 4. Test notification tool
notification_message = "test_notifications"
notification_result = await session.call_tool(
"notification_tool", {"message": notification_message}
)
notification_result = await session.call_tool("notification_tool", {"message": notification_message})
assert len(notification_result.content) == 1
assert isinstance(notification_result.content[0], TextContent)
assert "Sent notifications and logs" in notification_result.content[0].text
@@ -773,36 +723,24 @@ async def call_all_mcp_features(
# 2. Dynamic resource
resource_category = "test"
dynamic_content = await session.read_resource(
AnyUrl(f"resource://dynamic/{resource_category}")
)
dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}"))
assert isinstance(dynamic_content, ReadResourceResult)
assert len(dynamic_content.contents) == 1
assert isinstance(dynamic_content.contents[0], TextResourceContents)
assert (
f"Dynamic resource content for category: {resource_category}"
in dynamic_content.contents[0].text
)
assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text
# 3. Template resource
resource_id = "456"
template_content = await session.read_resource(
AnyUrl(f"resource://template/{resource_id}/data")
)
template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data"))
assert isinstance(template_content, ReadResourceResult)
assert len(template_content.contents) == 1
assert isinstance(template_content.contents[0], TextResourceContents)
assert (
f"Template resource data for ID: {resource_id}"
in template_content.contents[0].text
)
assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text
# Test prompts
# 1. Simple prompt
prompts = await session.list_prompts()
simple_prompt = next(
(p for p in prompts.prompts if p.name == "simple_prompt"), None
)
simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None)
assert simple_prompt is not None
prompt_topic = "AI"
@@ -812,16 +750,12 @@ async def call_all_mcp_features(
# The actual message structure depends on the prompt implementation
# 2. Complex prompt
complex_prompt = next(
(p for p in prompts.prompts if p.name == "complex_prompt"), None
)
complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None)
assert complex_prompt is not None
query = "What is AI?"
context = "technical"
complex_result = await session.get_prompt(
"complex_prompt", {"user_query": query, "context": context}
)
complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context})
assert isinstance(complex_result, GetPromptResult)
assert len(complex_result.messages) >= 1
@@ -837,9 +771,7 @@ async def call_all_mcp_features(
print(f"Received headers: {headers_data}")
# Test 6: Call tool that returns full context
context_result = await session.call_tool(
"echo_context", {"custom_request_id": "test-123"}
)
context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"})
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)
@@ -871,9 +803,7 @@ async def sampling_callback(
@pytest.mark.anyio
async def test_fastmcp_all_features_sse(
everything_server: None, everything_server_url: str
) -> None:
async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None:
"""Test all MCP features work correctly with SSE transport."""
# Create notification collector

View File

@@ -59,9 +59,7 @@ class TestServer:
"""Test SSE app creation with different mount paths."""
# Test with default mount path
mcp = FastMCP()
with patch.object(
mcp, "_normalize_path", return_value="/messages/"
) as mock_normalize:
with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize:
mcp.sse_app()
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/", "/messages/")
@@ -69,18 +67,14 @@ class TestServer:
# Test with custom mount path in settings
mcp = FastMCP()
mcp.settings.mount_path = "/custom"
with patch.object(
mcp, "_normalize_path", return_value="/custom/messages/"
) as mock_normalize:
with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize:
mcp.sse_app()
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/custom", "/messages/")
# Test with mount_path parameter
mcp = FastMCP()
with patch.object(
mcp, "_normalize_path", return_value="/param/messages/"
) as mock_normalize:
with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize:
mcp.sse_app(mount_path="/param")
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/param", "/messages/")
@@ -103,9 +97,7 @@ class TestServer:
# Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert (
mount_routes[0].path == "/messages"
), "Mount route path should be /messages"
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
# Test with mount path as parameter
mcp = FastMCP()
@@ -121,20 +113,14 @@ class TestServer:
# Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert (
mount_routes[0].path == "/messages"
), "Mount route path should be /messages"
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
@pytest.mark.anyio
async def test_non_ascii_description(self):
"""Test that FastMCP handles non-ASCII characters in descriptions correctly"""
mcp = FastMCP()
@mcp.tool(
description=(
"🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"
)
)
@mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"))
def hello_world(name: str = "世界") -> str:
return f"¡Hola, {name}! 👋"
@@ -187,9 +173,7 @@ class TestServer:
async def test_add_resource_decorator_incorrect_usage(self):
mcp = FastMCP()
with pytest.raises(
TypeError, match="The @resource decorator was used incorrectly"
):
with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"):
@mcp.resource # Missing parentheses #type: ignore
def get_data(x: str) -> str:
@@ -373,9 +357,7 @@ class TestServerResources:
def get_text():
return "Hello, world!"
resource = FunctionResource(
uri=AnyUrl("resource://test"), name="test", fn=get_text
)
resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text)
mcp.add_resource(resource)
async with client_session(mcp._mcp_server) as client:
@@ -411,9 +393,7 @@ class TestServerResources:
text_file = tmp_path / "test.txt"
text_file.write_text("Hello from file!")
resource = FileResource(
uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file
)
resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file)
mcp.add_resource(resource)
async with client_session(mcp._mcp_server) as client:
@@ -440,10 +420,7 @@ class TestServerResources:
async with client_session(mcp._mcp_server) as client:
result = await client.read_resource(AnyUrl("file://test.bin"))
assert isinstance(result.contents[0], BlobResourceContents)
assert (
result.contents[0].blob
== base64.b64encode(b"Binary file data").decode()
)
assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode()
@pytest.mark.anyio
async def test_function_resource(self):
@@ -532,9 +509,7 @@ class TestServerResourceTemplates:
return f"Data for {org}/{repo}"
async with client_session(mcp._mcp_server) as client:
result = await client.read_resource(
AnyUrl("resource://cursor/fastmcp/data")
)
result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data"))
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Data for cursor/fastmcp"

View File

@@ -147,9 +147,7 @@ class TestAddTools:
def test_add_lambda_with_no_name(self):
manager = ToolManager()
with pytest.raises(
ValueError, match="You must provide a name for lambda functions"
):
with pytest.raises(ValueError, match="You must provide a name for lambda functions"):
manager.add_tool(lambda x: x)
def test_warn_on_duplicate_tools(self, caplog):
@@ -346,9 +344,7 @@ class TestContextHandling:
tool = manager.add_tool(tool_without_context)
assert tool.context_kwarg is None
def tool_with_parametrized_context(
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
) -> str:
def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str:
return str(x)
tool = manager.add_tool(tool_with_parametrized_context)

View File

@@ -10,13 +10,7 @@ from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import (
ClientResult,
ServerNotification,
ServerRequest,
Tool,
ToolAnnotations,
)
from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations
@pytest.mark.anyio
@@ -45,18 +39,12 @@ async def test_lowlevel_server_tool_annotations():
)
]
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](10)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](10)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
# Message handler for client
async def message_handler(
message: RequestResponder[ServerRequest, ClientResult]
| ServerNotification
| Exception,
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message

View File

@@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path):
@server.read_resource()
async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]:
return [
ReadResourceContents(
content=b"Hello World", mime_type="application/octet-stream"
)
]
return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")]
# Get the handler directly from the server
handler = server.request_handlers[types.ReadResourceRequest]

View File

@@ -20,18 +20,12 @@ from mcp.types import (
@pytest.mark.anyio
async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -54,9 +48,7 @@ async def test_server_session_initialize():
if isinstance(message, Exception):
raise message
if isinstance(message, ClientNotification) and isinstance(
message.root, InitializedNotification
):
if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification):
received_initialized = True
return
@@ -111,12 +103,8 @@ async def test_server_capabilities():
@pytest.mark.anyio
async def test_server_session_initialize_with_older_protocol_version():
"""Test that server accepts and responds with older protocol (2024-11-05)."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage | Exception
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
received_initialized = False
received_protocol_version = None
@@ -137,9 +125,7 @@ async def test_server_session_initialize_with_older_protocol_version():
if isinstance(message, Exception):
raise message
if isinstance(message, types.ClientNotification) and isinstance(
message.root, InitializedNotification
):
if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification):
received_initialized = True
return
@@ -157,9 +143,7 @@ async def test_server_session_initialize_with_older_protocol_version():
params=types.InitializeRequestParams(
protocolVersion="2024-11-05",
capabilities=types.ClientCapabilities(),
clientInfo=types.Implementation(
name="test-client", version="1.0.0"
),
clientInfo=types.Implementation(name="test-client", version="1.0.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
)

View File

@@ -22,9 +22,10 @@ async def test_stdio_server():
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
stdin.seek(0)
async with stdio_server(
stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)
) as (read_stream, write_stream):
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
read_stream,
write_stream,
):
received_messages = []
async with read_stream:
async for message in read_stream:
@@ -36,12 +37,8 @@ async def test_stdio_server():
# Verify received messages
assert len(received_messages) == 2
assert received_messages[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
)
assert received_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
# Test sending responses from the server
responses = [
@@ -58,13 +55,7 @@ async def test_stdio_server():
output_lines = stdout.readlines()
assert len(output_lines) == 2
received_responses = [
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
]
received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines]
assert len(received_responses) == 2
assert received_responses[0] == JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
)
assert received_responses[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
)
assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"))
assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}))

View File

@@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once():
async with manager.run():
pass
assert (
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(excinfo.value)
)
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value)
@pytest.mark.anyio
@@ -51,10 +48,7 @@ async def test_run_prevents_concurrent_calls():
# One should succeed, one should fail
assert len(errors) == 1
assert (
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(errors[0])
)
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])
@pytest.mark.anyio
@@ -76,6 +70,4 @@ async def test_handle_request_without_run_raises_error():
with pytest.raises(RuntimeError) as excinfo:
await manager.handle_request(scope, receive, send)
assert "Task group is not initialized. Make sure to use run()." in str(
excinfo.value
)
assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)