mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-22 00:04:21 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken:
|
||||
class TestBearerAuthBackend:
|
||||
"""Tests for the BearerAuthBackend class."""
|
||||
|
||||
async def test_no_auth_header(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""Test authentication with no Authorization header."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request({"type": "http", "headers": []})
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_non_bearer_auth_header(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""Test authentication with non-Bearer Authorization header."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request(
|
||||
@@ -139,9 +135,7 @@ class TestBearerAuthBackend:
|
||||
result = await backend.authenticate(request)
|
||||
assert result is None
|
||||
|
||||
async def test_invalid_token(
|
||||
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
):
|
||||
async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""Test authentication with invalid token."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
request = Request(
|
||||
@@ -160,9 +154,7 @@ class TestBearerAuthBackend:
|
||||
):
|
||||
"""Test authentication with expired token."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(
|
||||
mock_oauth_provider, "expired_token", expired_access_token
|
||||
)
|
||||
add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
@@ -203,9 +195,7 @@ class TestBearerAuthBackend:
|
||||
):
|
||||
"""Test authentication with token that has no expiry."""
|
||||
backend = BearerAuthBackend(provider=mock_oauth_provider)
|
||||
add_token_to_provider(
|
||||
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
|
||||
)
|
||||
add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
|
||||
request = Request(
|
||||
{
|
||||
"type": "http",
|
||||
|
||||
@@ -128,16 +128,12 @@ class TestRegistrationErrorHandling:
|
||||
|
||||
class TestAuthorizeErrorHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_error_handling(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge):
|
||||
# Mock the authorize method to raise an authorize error
|
||||
with unittest.mock.patch.object(
|
||||
oauth_provider,
|
||||
"authorize",
|
||||
side_effect=AuthorizeError(
|
||||
error="access_denied", error_description="The user denied the request"
|
||||
),
|
||||
side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"),
|
||||
):
|
||||
# Register the client
|
||||
client_id = registered_client["client_id"]
|
||||
@@ -169,9 +165,7 @@ class TestAuthorizeErrorHandling:
|
||||
|
||||
class TestTokenErrorHandling:
|
||||
@pytest.mark.anyio
|
||||
async def test_token_error_handling_auth_code(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge):
|
||||
# Register the client and get an auth code
|
||||
client_id = registered_client["client_id"]
|
||||
client_secret = registered_client["client_secret"]
|
||||
@@ -224,9 +218,7 @@ class TestTokenErrorHandling:
|
||||
assert data["error_description"] == "The authorization code is invalid"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_error_handling_refresh_token(
|
||||
self, client, oauth_provider, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge):
|
||||
# Register the client and get tokens
|
||||
client_id = registered_client["client_id"]
|
||||
client_secret = registered_client["client_secret"]
|
||||
|
||||
@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
async def register_client(self, client_info: OAuthClientInformationFull):
|
||||
self.clients[client_info.client_id] = client_info
|
||||
|
||||
async def authorize(
|
||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||
) -> str:
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
# toy authorize implementation which just immediately generates an authorization
|
||||
# code and completes the redirect
|
||||
code = AuthorizationCode(
|
||||
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
)
|
||||
self.auth_codes[code.code] = code
|
||||
|
||||
return construct_redirect_uri(
|
||||
str(params.redirect_uri), code=code.code, state=params.state
|
||||
)
|
||||
return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> RefreshToken | None:
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
|
||||
old_access_token = self.refresh_tokens.get(refresh_token)
|
||||
if old_access_token is None:
|
||||
return None
|
||||
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(auth_app):
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
|
||||
) as client:
|
||||
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
|
||||
def pkce_challenge():
|
||||
"""Create a PKCE challenge with code_verifier and code_challenge."""
|
||||
code_verifier = "some_random_verifier_string"
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")
|
||||
|
||||
return {"code_verifier": code_verifier, "code_challenge": code_challenge}
|
||||
|
||||
@@ -356,17 +344,13 @@ class TestAuthEndpoints:
|
||||
|
||||
metadata = response.json()
|
||||
assert metadata["issuer"] == "https://auth.example.com/"
|
||||
assert (
|
||||
metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
)
|
||||
assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
assert metadata["token_endpoint"] == "https://auth.example.com/token"
|
||||
assert metadata["registration_endpoint"] == "https://auth.example.com/register"
|
||||
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
|
||||
assert metadata["response_types_supported"] == ["code"]
|
||||
assert metadata["code_challenge_methods_supported"] == ["S256"]
|
||||
assert metadata["token_endpoint_auth_methods_supported"] == [
|
||||
"client_secret_post"
|
||||
]
|
||||
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
|
||||
assert metadata["grant_types_supported"] == [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
|
||||
)
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_request"
|
||||
assert (
|
||||
"error_description" in error_response
|
||||
) # Contains validation error messages
|
||||
assert "error_description" in error_response # Contains validation error messages
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_invalid_auth_code(
|
||||
self, test_client, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
|
||||
"""Test token endpoint error - authorization code does not exist."""
|
||||
# Try to use a non-existent authorization code
|
||||
response = await test_client.post(
|
||||
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_grant"
|
||||
assert (
|
||||
"authorization code does not exist" in error_response["error_description"]
|
||||
)
|
||||
assert "authorization code does not exist" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_expired_auth_code(
|
||||
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["error"] == "invalid_grant"
|
||||
assert (
|
||||
"authorization code has expired" in error_response["error_description"]
|
||||
)
|
||||
assert "authorization code has expired" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_token_redirect_uri_mismatch(
|
||||
self, test_client, registered_client, auth_code, pkce_challenge
|
||||
):
|
||||
async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
|
||||
"""Test token endpoint error - redirect URI mismatch."""
|
||||
# Try to use the code with a different redirect URI
|
||||
response = await test_client.post(
|
||||
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
|
||||
assert "redirect_uri did not match" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_code_verifier_mismatch(
|
||||
self, test_client, registered_client, auth_code
|
||||
):
|
||||
async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
|
||||
"""Test token endpoint error - PKCE code verifier mismatch."""
|
||||
# Try to use the code with an incorrect code verifier
|
||||
response = await test_client.post(
|
||||
@@ -569,9 +541,7 @@ class TestAuthEndpoints:
|
||||
|
||||
# Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
|
||||
# Mock the time.time() function to return a value 4 hours in the future
|
||||
with unittest.mock.patch(
|
||||
"time.time", return_value=current_time + 14400
|
||||
): # 4 hours = 14400 seconds
|
||||
with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
|
||||
# Try to use the refresh token which should now be considered expired
|
||||
response = await test_client.post(
|
||||
"/token",
|
||||
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
|
||||
assert "refresh token has expired" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_invalid_scope(
|
||||
self, test_client, registered_client, auth_code, pkce_challenge
|
||||
):
|
||||
async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
|
||||
"""Test token endpoint error - invalid scope in refresh token request."""
|
||||
# Exchange authorization code for tokens
|
||||
token_response = await test_client.post(
|
||||
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
|
||||
assert "cannot request scope" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration(
|
||||
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
|
||||
):
|
||||
async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
|
||||
"""Test client registration."""
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
|
||||
# ) is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_missing_required_fields(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with missing required fields."""
|
||||
# Missing redirect_uris which is a required field
|
||||
client_metadata = {
|
||||
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
|
||||
assert error_data["error_description"] == "redirect_uris: Field required"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_invalid_uri(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with invalid URIs."""
|
||||
# Invalid redirect_uri format
|
||||
client_metadata = {
|
||||
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert error_data["error_description"] == (
|
||||
"redirect_uris.0: Input should be a valid URL, "
|
||||
"relative URL without a base"
|
||||
"redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_empty_redirect_uris(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with empty redirect_uris array."""
|
||||
client_metadata = {
|
||||
"redirect_uris": [], # Empty array
|
||||
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert (
|
||||
error_data["error_description"]
|
||||
== "redirect_uris: List should have at least 1 item after validation, not 0"
|
||||
error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify that the token was revoked
|
||||
assert (
|
||||
await mock_oauth_provider.load_access_token(
|
||||
new_token_response["access_token"]
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_revoke_invalid_token(self, test_client, registered_client):
|
||||
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
|
||||
assert "token_type_hint" in error_response["error_description"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_disallowed_scopes(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
|
||||
"""Test client registration with scopes that are not allowed."""
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
|
||||
assert client_info["scope"] == "read write"
|
||||
|
||||
# Retrieve the client from the store to verify default scopes
|
||||
registered_client = await mock_oauth_provider.get_client(
|
||||
client_info["client_id"]
|
||||
)
|
||||
registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
|
||||
assert registered_client is not None
|
||||
|
||||
# Check that default scopes were applied
|
||||
assert registered_client.scope == "read write"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_client_registration_invalid_grant_type(
|
||||
self, test_client: httpx.AsyncClient
|
||||
):
|
||||
async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
|
||||
client_metadata = {
|
||||
"redirect_uris": ["https://client.example.com/callback"],
|
||||
"client_name": "Test Client",
|
||||
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
|
||||
error_data = response.json()
|
||||
assert "error" in error_data
|
||||
assert error_data["error"] == "invalid_client_metadata"
|
||||
assert (
|
||||
error_data["error_description"]
|
||||
== "grant_types must be authorization_code and refresh_token"
|
||||
)
|
||||
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
|
||||
|
||||
|
||||
class TestAuthorizeEndpointErrors:
|
||||
"""Test error handling in the OAuth authorization endpoint."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_missing_client_id(
|
||||
self, test_client: httpx.AsyncClient, pkce_challenge
|
||||
):
|
||||
async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
|
||||
"""Test authorization endpoint with missing client_id.
|
||||
|
||||
According to the OAuth2.0 spec, if client_id is missing, the server should
|
||||
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert "client_id" in response.text.lower()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_invalid_client_id(
|
||||
self, test_client: httpx.AsyncClient, pkce_challenge
|
||||
):
|
||||
async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
|
||||
"""Test authorization endpoint with invalid client_id.
|
||||
|
||||
According to the OAuth2.0 spec, if client_id is invalid, the server should
|
||||
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert query_params["state"][0] == "test_state"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_missing_pkce_challenge(
|
||||
self, test_client: httpx.AsyncClient, registered_client
|
||||
):
|
||||
async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
|
||||
"""Test authorization endpoint with missing PKCE code_challenge.
|
||||
|
||||
Missing PKCE parameters should result in invalid_request error.
|
||||
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
|
||||
assert query_params["state"][0] == "test_state"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_authorize_invalid_scope(
|
||||
self, test_client: httpx.AsyncClient, registered_client, pkce_challenge
|
||||
):
|
||||
async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
|
||||
"""Test authorization endpoint with invalid scope.
|
||||
|
||||
Invalid scope should redirect with invalid_scope error.
|
||||
|
||||
@@ -18,9 +18,7 @@ class TestRenderPrompt:
|
||||
return "Hello, world!"
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render() == [
|
||||
UserMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
]
|
||||
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_fn(self):
|
||||
@@ -28,9 +26,7 @@ class TestRenderPrompt:
|
||||
return "Hello, world!"
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render() == [
|
||||
UserMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
]
|
||||
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fn_with_args(self):
|
||||
@@ -39,11 +35,7 @@ class TestRenderPrompt:
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render(arguments={"name": "World"}) == [
|
||||
UserMessage(
|
||||
content=TextContent(
|
||||
type="text", text="Hello, World! You're 30 years old."
|
||||
)
|
||||
)
|
||||
UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old."))
|
||||
]
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -61,21 +53,15 @@ class TestRenderPrompt:
|
||||
return UserMessage(content="Hello, world!")
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render() == [
|
||||
UserMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
]
|
||||
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fn_returns_assistant_message(self):
|
||||
async def fn() -> AssistantMessage:
|
||||
return AssistantMessage(
|
||||
content=TextContent(type="text", text="Hello, world!")
|
||||
)
|
||||
return AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render() == [
|
||||
AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
]
|
||||
assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fn_returns_multiple_messages(self):
|
||||
@@ -156,9 +142,7 @@ class TestRenderPrompt:
|
||||
|
||||
prompt = Prompt.from_function(fn)
|
||||
assert await prompt.render() == [
|
||||
UserMessage(
|
||||
content=TextContent(type="text", text="Please analyze this file:")
|
||||
),
|
||||
UserMessage(content=TextContent(type="text", text="Please analyze this file:")),
|
||||
UserMessage(
|
||||
content=EmbeddedResource(
|
||||
type="resource",
|
||||
@@ -169,9 +153,7 @@ class TestRenderPrompt:
|
||||
),
|
||||
)
|
||||
),
|
||||
AssistantMessage(
|
||||
content=TextContent(type="text", text="I'll help analyze that file.")
|
||||
),
|
||||
AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")),
|
||||
]
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -72,9 +72,7 @@ class TestPromptManager:
|
||||
prompt = Prompt.from_function(fn)
|
||||
manager.add_prompt(prompt)
|
||||
messages = await manager.render_prompt("fn")
|
||||
assert messages == [
|
||||
UserMessage(content=TextContent(type="text", text="Hello, world!"))
|
||||
]
|
||||
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_render_prompt_with_args(self):
|
||||
@@ -87,9 +85,7 @@ class TestPromptManager:
|
||||
prompt = Prompt.from_function(fn)
|
||||
manager.add_prompt(prompt)
|
||||
messages = await manager.render_prompt("fn", arguments={"name": "World"})
|
||||
assert messages == [
|
||||
UserMessage(content=TextContent(type="text", text="Hello, World!"))
|
||||
]
|
||||
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_render_unknown_prompt(self):
|
||||
|
||||
@@ -100,9 +100,7 @@ class TestFileResource:
|
||||
with pytest.raises(ValueError, match="Error reading file"):
|
||||
await resource.read()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.name == "nt", reason="File permissions behave differently on Windows"
|
||||
)
|
||||
@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows")
|
||||
@pytest.mark.anyio
|
||||
async def test_permission_error(self, temp_file: Path):
|
||||
"""Test reading a file without permissions."""
|
||||
|
||||
@@ -28,9 +28,7 @@ def complex_arguments_fn(
|
||||
# list[str] | str is an interesting case because if it comes in as JSON like
|
||||
# "[\"a\", \"b\"]" then it will be naively parsed as a string.
|
||||
list_str_or_str: list[str] | str,
|
||||
an_int_annotated_with_field: Annotated[
|
||||
int, Field(description="An int with a field")
|
||||
],
|
||||
an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")],
|
||||
an_int_annotated_with_field_and_others: Annotated[
|
||||
int,
|
||||
str, # Should be ignored, really
|
||||
@@ -42,9 +40,7 @@ def complex_arguments_fn(
|
||||
"123",
|
||||
456,
|
||||
],
|
||||
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[
|
||||
int, Field(1)
|
||||
],
|
||||
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)],
|
||||
unannotated,
|
||||
my_model_a: SomeInputModelA,
|
||||
my_model_a_forward_ref: "SomeInputModelA",
|
||||
@@ -179,9 +175,7 @@ def test_str_vs_list_str():
|
||||
def test_skip_names():
|
||||
"""Test that skipped parameters are not included in the model"""
|
||||
|
||||
def func_with_many_params(
|
||||
keep_this: int, skip_this: str, also_keep: float, also_skip: bool
|
||||
):
|
||||
def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool):
|
||||
return keep_this, skip_this, also_keep, also_skip
|
||||
|
||||
# Skip some parameters
|
||||
|
||||
@@ -130,11 +130,7 @@ def make_everything_fastmcp() -> FastMCP:
|
||||
|
||||
# Request sampling from the client
|
||||
result = await ctx.session.create_message(
|
||||
messages=[
|
||||
SamplingMessage(
|
||||
role="user", content=TextContent(type="text", text=prompt)
|
||||
)
|
||||
],
|
||||
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))],
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
)
|
||||
@@ -278,11 +274,7 @@ def make_fastmcp_stateless_http_app():
|
||||
def run_server(server_port: int) -> None:
|
||||
"""Run the server."""
|
||||
_, app = make_fastmcp_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"Starting server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -290,11 +282,7 @@ def run_server(server_port: int) -> None:
|
||||
def run_everything_legacy_sse_http_server(server_port: int) -> None:
|
||||
"""Run the comprehensive server with all features."""
|
||||
_, app = make_everything_fastmcp_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"Starting comprehensive server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -302,11 +290,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None:
|
||||
def run_streamable_http_server(server_port: int) -> None:
|
||||
"""Run the StreamableHTTP server."""
|
||||
_, app = make_fastmcp_streamable_http_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"Starting StreamableHTTP server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -314,11 +298,7 @@ def run_streamable_http_server(server_port: int) -> None:
|
||||
def run_everything_server(server_port: int) -> None:
|
||||
"""Run the comprehensive StreamableHTTP server with all features."""
|
||||
_, app = make_everything_fastmcp_streamable_http_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"Starting comprehensive StreamableHTTP server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -326,11 +306,7 @@ def run_everything_server(server_port: int) -> None:
|
||||
def run_stateless_http_server(server_port: int) -> None:
|
||||
"""Run the stateless StreamableHTTP server."""
|
||||
_, app = make_fastmcp_stateless_http_app()
|
||||
server = uvicorn.Server(
|
||||
config=uvicorn.Config(
|
||||
app=app, host="127.0.0.1", port=server_port, log_level="error"
|
||||
)
|
||||
)
|
||||
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
|
||||
print(f"Starting stateless StreamableHTTP server on port {server_port}")
|
||||
server.run()
|
||||
|
||||
@@ -369,9 +345,7 @@ def server(server_port: int) -> Generator[None, None, None]:
|
||||
@pytest.fixture()
|
||||
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
|
||||
"""Start the StreamableHTTP server in a separate process."""
|
||||
proc = multiprocessing.Process(
|
||||
target=run_streamable_http_server, args=(http_server_port,), daemon=True
|
||||
)
|
||||
proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True)
|
||||
print("Starting StreamableHTTP server process")
|
||||
proc.start()
|
||||
|
||||
@@ -388,9 +362,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None]
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"StreamableHTTP server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -427,9 +399,7 @@ def stateless_http_server(
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Stateless server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -459,9 +429,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_streamable_http(
|
||||
streamable_http_server: None, http_server_url: str
|
||||
) -> None:
|
||||
async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None:
|
||||
"""Test that FastMCP works with StreamableHTTP transport."""
|
||||
# Connect to the server using StreamableHTTP
|
||||
async with streamablehttp_client(http_server_url + "/mcp") as (
|
||||
@@ -484,9 +452,7 @@ async def test_fastmcp_streamable_http(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_stateless_streamable_http(
|
||||
stateless_http_server: None, stateless_http_server_url: str
|
||||
) -> None:
|
||||
async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None:
|
||||
"""Test that FastMCP works with stateless StreamableHTTP transport."""
|
||||
# Connect to the server using StreamableHTTP
|
||||
async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
|
||||
@@ -562,9 +528,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Comprehensive server failed to start after {max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -601,10 +565,7 @@ def everything_streamable_http_server(
|
||||
time.sleep(0.1)
|
||||
attempt += 1
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Comprehensive StreamableHTTP server failed to start after "
|
||||
f"{max_attempts} attempts"
|
||||
)
|
||||
raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts")
|
||||
|
||||
yield
|
||||
|
||||
@@ -648,9 +609,7 @@ class NotificationCollector:
|
||||
await self.handle_tool_list_changed(message.root.params)
|
||||
|
||||
|
||||
async def call_all_mcp_features(
|
||||
session: ClientSession, collector: NotificationCollector
|
||||
) -> None:
|
||||
async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None:
|
||||
"""
|
||||
Test all MCP features using the provided session.
|
||||
|
||||
@@ -680,9 +639,7 @@ async def call_all_mcp_features(
|
||||
# Test progress callback functionality
|
||||
progress_updates = []
|
||||
|
||||
async def progress_callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
|
||||
"""Collect progress updates for testing (async version)."""
|
||||
progress_updates.append((progress, total, message))
|
||||
print(f"Progress: {progress}/{total} - {message}")
|
||||
@@ -726,19 +683,12 @@ async def call_all_mcp_features(
|
||||
|
||||
# Verify we received log messages from the sampling tool
|
||||
assert len(collector.log_messages) > 0
|
||||
assert any(
|
||||
"Requesting sampling for prompt" in msg.data for msg in collector.log_messages
|
||||
)
|
||||
assert any(
|
||||
"Received sampling result from model" in msg.data
|
||||
for msg in collector.log_messages
|
||||
)
|
||||
assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages)
|
||||
assert any("Received sampling result from model" in msg.data for msg in collector.log_messages)
|
||||
|
||||
# 4. Test notification tool
|
||||
notification_message = "test_notifications"
|
||||
notification_result = await session.call_tool(
|
||||
"notification_tool", {"message": notification_message}
|
||||
)
|
||||
notification_result = await session.call_tool("notification_tool", {"message": notification_message})
|
||||
assert len(notification_result.content) == 1
|
||||
assert isinstance(notification_result.content[0], TextContent)
|
||||
assert "Sent notifications and logs" in notification_result.content[0].text
|
||||
@@ -773,36 +723,24 @@ async def call_all_mcp_features(
|
||||
|
||||
# 2. Dynamic resource
|
||||
resource_category = "test"
|
||||
dynamic_content = await session.read_resource(
|
||||
AnyUrl(f"resource://dynamic/{resource_category}")
|
||||
)
|
||||
dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}"))
|
||||
assert isinstance(dynamic_content, ReadResourceResult)
|
||||
assert len(dynamic_content.contents) == 1
|
||||
assert isinstance(dynamic_content.contents[0], TextResourceContents)
|
||||
assert (
|
||||
f"Dynamic resource content for category: {resource_category}"
|
||||
in dynamic_content.contents[0].text
|
||||
)
|
||||
assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text
|
||||
|
||||
# 3. Template resource
|
||||
resource_id = "456"
|
||||
template_content = await session.read_resource(
|
||||
AnyUrl(f"resource://template/{resource_id}/data")
|
||||
)
|
||||
template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data"))
|
||||
assert isinstance(template_content, ReadResourceResult)
|
||||
assert len(template_content.contents) == 1
|
||||
assert isinstance(template_content.contents[0], TextResourceContents)
|
||||
assert (
|
||||
f"Template resource data for ID: {resource_id}"
|
||||
in template_content.contents[0].text
|
||||
)
|
||||
assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text
|
||||
|
||||
# Test prompts
|
||||
# 1. Simple prompt
|
||||
prompts = await session.list_prompts()
|
||||
simple_prompt = next(
|
||||
(p for p in prompts.prompts if p.name == "simple_prompt"), None
|
||||
)
|
||||
simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None)
|
||||
assert simple_prompt is not None
|
||||
|
||||
prompt_topic = "AI"
|
||||
@@ -812,16 +750,12 @@ async def call_all_mcp_features(
|
||||
# The actual message structure depends on the prompt implementation
|
||||
|
||||
# 2. Complex prompt
|
||||
complex_prompt = next(
|
||||
(p for p in prompts.prompts if p.name == "complex_prompt"), None
|
||||
)
|
||||
complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None)
|
||||
assert complex_prompt is not None
|
||||
|
||||
query = "What is AI?"
|
||||
context = "technical"
|
||||
complex_result = await session.get_prompt(
|
||||
"complex_prompt", {"user_query": query, "context": context}
|
||||
)
|
||||
complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context})
|
||||
assert isinstance(complex_result, GetPromptResult)
|
||||
assert len(complex_result.messages) >= 1
|
||||
|
||||
@@ -837,9 +771,7 @@ async def call_all_mcp_features(
|
||||
print(f"Received headers: {headers_data}")
|
||||
|
||||
# Test 6: Call tool that returns full context
|
||||
context_result = await session.call_tool(
|
||||
"echo_context", {"custom_request_id": "test-123"}
|
||||
)
|
||||
context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"})
|
||||
assert len(context_result.content) == 1
|
||||
assert isinstance(context_result.content[0], TextContent)
|
||||
|
||||
@@ -871,9 +803,7 @@ async def sampling_callback(
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fastmcp_all_features_sse(
|
||||
everything_server: None, everything_server_url: str
|
||||
) -> None:
|
||||
async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None:
|
||||
"""Test all MCP features work correctly with SSE transport."""
|
||||
|
||||
# Create notification collector
|
||||
|
||||
@@ -59,9 +59,7 @@ class TestServer:
|
||||
"""Test SSE app creation with different mount paths."""
|
||||
# Test with default mount path
|
||||
mcp = FastMCP()
|
||||
with patch.object(
|
||||
mcp, "_normalize_path", return_value="/messages/"
|
||||
) as mock_normalize:
|
||||
with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize:
|
||||
mcp.sse_app()
|
||||
# Verify _normalize_path was called with correct args
|
||||
mock_normalize.assert_called_once_with("/", "/messages/")
|
||||
@@ -69,18 +67,14 @@ class TestServer:
|
||||
# Test with custom mount path in settings
|
||||
mcp = FastMCP()
|
||||
mcp.settings.mount_path = "/custom"
|
||||
with patch.object(
|
||||
mcp, "_normalize_path", return_value="/custom/messages/"
|
||||
) as mock_normalize:
|
||||
with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize:
|
||||
mcp.sse_app()
|
||||
# Verify _normalize_path was called with correct args
|
||||
mock_normalize.assert_called_once_with("/custom", "/messages/")
|
||||
|
||||
# Test with mount_path parameter
|
||||
mcp = FastMCP()
|
||||
with patch.object(
|
||||
mcp, "_normalize_path", return_value="/param/messages/"
|
||||
) as mock_normalize:
|
||||
with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize:
|
||||
mcp.sse_app(mount_path="/param")
|
||||
# Verify _normalize_path was called with correct args
|
||||
mock_normalize.assert_called_once_with("/param", "/messages/")
|
||||
@@ -103,9 +97,7 @@ class TestServer:
|
||||
|
||||
# Verify path values
|
||||
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
|
||||
assert (
|
||||
mount_routes[0].path == "/messages"
|
||||
), "Mount route path should be /messages"
|
||||
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
|
||||
|
||||
# Test with mount path as parameter
|
||||
mcp = FastMCP()
|
||||
@@ -121,20 +113,14 @@ class TestServer:
|
||||
|
||||
# Verify path values
|
||||
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
|
||||
assert (
|
||||
mount_routes[0].path == "/messages"
|
||||
), "Mount route path should be /messages"
|
||||
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_ascii_description(self):
|
||||
"""Test that FastMCP handles non-ASCII characters in descriptions correctly"""
|
||||
mcp = FastMCP()
|
||||
|
||||
@mcp.tool(
|
||||
description=(
|
||||
"🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"
|
||||
)
|
||||
)
|
||||
@mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"))
|
||||
def hello_world(name: str = "世界") -> str:
|
||||
return f"¡Hola, {name}! 👋"
|
||||
|
||||
@@ -187,9 +173,7 @@ class TestServer:
|
||||
async def test_add_resource_decorator_incorrect_usage(self):
|
||||
mcp = FastMCP()
|
||||
|
||||
with pytest.raises(
|
||||
TypeError, match="The @resource decorator was used incorrectly"
|
||||
):
|
||||
with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"):
|
||||
|
||||
@mcp.resource # Missing parentheses #type: ignore
|
||||
def get_data(x: str) -> str:
|
||||
@@ -373,9 +357,7 @@ class TestServerResources:
|
||||
def get_text():
|
||||
return "Hello, world!"
|
||||
|
||||
resource = FunctionResource(
|
||||
uri=AnyUrl("resource://test"), name="test", fn=get_text
|
||||
)
|
||||
resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text)
|
||||
mcp.add_resource(resource)
|
||||
|
||||
async with client_session(mcp._mcp_server) as client:
|
||||
@@ -411,9 +393,7 @@ class TestServerResources:
|
||||
text_file = tmp_path / "test.txt"
|
||||
text_file.write_text("Hello from file!")
|
||||
|
||||
resource = FileResource(
|
||||
uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file
|
||||
)
|
||||
resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file)
|
||||
mcp.add_resource(resource)
|
||||
|
||||
async with client_session(mcp._mcp_server) as client:
|
||||
@@ -440,10 +420,7 @@ class TestServerResources:
|
||||
async with client_session(mcp._mcp_server) as client:
|
||||
result = await client.read_resource(AnyUrl("file://test.bin"))
|
||||
assert isinstance(result.contents[0], BlobResourceContents)
|
||||
assert (
|
||||
result.contents[0].blob
|
||||
== base64.b64encode(b"Binary file data").decode()
|
||||
)
|
||||
assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_function_resource(self):
|
||||
@@ -532,9 +509,7 @@ class TestServerResourceTemplates:
|
||||
return f"Data for {org}/{repo}"
|
||||
|
||||
async with client_session(mcp._mcp_server) as client:
|
||||
result = await client.read_resource(
|
||||
AnyUrl("resource://cursor/fastmcp/data")
|
||||
)
|
||||
result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data"))
|
||||
assert isinstance(result.contents[0], TextResourceContents)
|
||||
assert result.contents[0].text == "Data for cursor/fastmcp"
|
||||
|
||||
|
||||
@@ -147,9 +147,7 @@ class TestAddTools:
|
||||
|
||||
def test_add_lambda_with_no_name(self):
|
||||
manager = ToolManager()
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide a name for lambda functions"
|
||||
):
|
||||
with pytest.raises(ValueError, match="You must provide a name for lambda functions"):
|
||||
manager.add_tool(lambda x: x)
|
||||
|
||||
def test_warn_on_duplicate_tools(self, caplog):
|
||||
@@ -346,9 +344,7 @@ class TestContextHandling:
|
||||
tool = manager.add_tool(tool_without_context)
|
||||
assert tool.context_kwarg is None
|
||||
|
||||
def tool_with_parametrized_context(
|
||||
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
|
||||
) -> str:
|
||||
def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str:
|
||||
return str(x)
|
||||
|
||||
tool = manager.add_tool(tool_with_parametrized_context)
|
||||
|
||||
@@ -10,13 +10,7 @@ from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
from mcp.types import (
|
||||
ClientResult,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
Tool,
|
||||
ToolAnnotations,
|
||||
)
|
||||
from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -45,18 +39,12 @@ async def test_lowlevel_server_tool_annotations():
|
||||
)
|
||||
]
|
||||
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](10)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](10)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
|
||||
|
||||
# Message handler for client
|
||||
async def message_handler(
|
||||
message: RequestResponder[ServerRequest, ClientResult]
|
||||
| ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
@@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path):
|
||||
|
||||
@server.read_resource()
|
||||
async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]:
|
||||
return [
|
||||
ReadResourceContents(
|
||||
content=b"Hello World", mime_type="application/octet-stream"
|
||||
)
|
||||
]
|
||||
return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")]
|
||||
|
||||
# Get the handler directly from the server
|
||||
handler = server.request_handlers[types.ReadResourceRequest]
|
||||
|
||||
@@ -20,18 +20,12 @@ from mcp.types import (
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_server_session_initialize():
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
|
||||
# Create a message handler to catch exceptions
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
||||
| types.ServerNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
@@ -54,9 +48,7 @@ async def test_server_session_initialize():
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
if isinstance(message, ClientNotification) and isinstance(
|
||||
message.root, InitializedNotification
|
||||
):
|
||||
if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification):
|
||||
received_initialized = True
|
||||
return
|
||||
|
||||
@@ -111,12 +103,8 @@ async def test_server_capabilities():
|
||||
@pytest.mark.anyio
|
||||
async def test_server_session_initialize_with_older_protocol_version():
|
||||
"""Test that server accepts and responds with older protocol (2024-11-05)."""
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
SessionMessage | Exception
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
|
||||
|
||||
received_initialized = False
|
||||
received_protocol_version = None
|
||||
@@ -137,9 +125,7 @@ async def test_server_session_initialize_with_older_protocol_version():
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
if isinstance(message, types.ClientNotification) and isinstance(
|
||||
message.root, InitializedNotification
|
||||
):
|
||||
if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification):
|
||||
received_initialized = True
|
||||
return
|
||||
|
||||
@@ -157,9 +143,7 @@ async def test_server_session_initialize_with_older_protocol_version():
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion="2024-11-05",
|
||||
capabilities=types.ClientCapabilities(),
|
||||
clientInfo=types.Implementation(
|
||||
name="test-client", version="1.0.0"
|
||||
),
|
||||
clientInfo=types.Implementation(name="test-client", version="1.0.0"),
|
||||
).model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -22,9 +22,10 @@ async def test_stdio_server():
|
||||
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
|
||||
stdin.seek(0)
|
||||
|
||||
async with stdio_server(
|
||||
stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)
|
||||
) as (read_stream, write_stream):
|
||||
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
|
||||
read_stream,
|
||||
write_stream,
|
||||
):
|
||||
received_messages = []
|
||||
async with read_stream:
|
||||
async for message in read_stream:
|
||||
@@ -36,12 +37,8 @@ async def test_stdio_server():
|
||||
|
||||
# Verify received messages
|
||||
assert len(received_messages) == 2
|
||||
assert received_messages[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
|
||||
)
|
||||
assert received_messages[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
|
||||
)
|
||||
assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
|
||||
assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
|
||||
|
||||
# Test sending responses from the server
|
||||
responses = [
|
||||
@@ -58,13 +55,7 @@ async def test_stdio_server():
|
||||
output_lines = stdout.readlines()
|
||||
assert len(output_lines) == 2
|
||||
|
||||
received_responses = [
|
||||
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
|
||||
]
|
||||
received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines]
|
||||
assert len(received_responses) == 2
|
||||
assert received_responses[0] == JSONRPCMessage(
|
||||
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
|
||||
)
|
||||
assert received_responses[1] == JSONRPCMessage(
|
||||
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
|
||||
)
|
||||
assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"))
|
||||
assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}))
|
||||
|
||||
@@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once():
|
||||
async with manager.run():
|
||||
pass
|
||||
|
||||
assert (
|
||||
"StreamableHTTPSessionManager .run() can only be called once per instance"
|
||||
in str(excinfo.value)
|
||||
)
|
||||
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -51,10 +48,7 @@ async def test_run_prevents_concurrent_calls():
|
||||
|
||||
# One should succeed, one should fail
|
||||
assert len(errors) == 1
|
||||
assert (
|
||||
"StreamableHTTPSessionManager .run() can only be called once per instance"
|
||||
in str(errors[0])
|
||||
)
|
||||
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -76,6 +70,4 @@ async def test_handle_request_without_run_raises_error():
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
await manager.handle_request(scope, receive, send)
|
||||
|
||||
assert "Task group is not initialized. Make sure to use run()." in str(
|
||||
excinfo.value
|
||||
)
|
||||
assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)
|
||||
|
||||
Reference in New Issue
Block a user