mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -2,7 +2,4 @@ from pydantic import ValidationError
|
||||
|
||||
|
||||
def stringify_pydantic_error(validation_error: ValidationError) -> str:
|
||||
return "\n".join(
|
||||
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
|
||||
for e in validation_error.errors()
|
||||
)
|
||||
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
|
||||
|
||||
@@ -7,9 +7,7 @@ from starlette.datastructures import FormData, QueryParams
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizationErrorCode,
|
||||
@@ -18,10 +16,7 @@ from mcp.server.auth.provider import (
|
||||
OAuthAuthorizationServerProvider,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import (
|
||||
InvalidRedirectUriError,
|
||||
InvalidScopeError,
|
||||
)
|
||||
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,23 +24,16 @@ logger = logging.getLogger(__name__)
|
||||
class AuthorizationRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
client_id: str = Field(..., description="The client ID")
|
||||
redirect_uri: AnyUrl | None = Field(
|
||||
None, description="URL to redirect to after authorization"
|
||||
)
|
||||
redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
|
||||
|
||||
# see OAuthClientMetadata; we only support `code`
|
||||
response_type: Literal["code"] = Field(
|
||||
..., description="Must be 'code' for authorization code flow"
|
||||
)
|
||||
response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge")
|
||||
code_challenge_method: Literal["S256"] = Field(
|
||||
"S256", description="PKCE code challenge method, must be S256"
|
||||
)
|
||||
code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
|
||||
state: str | None = Field(None, description="Optional state parameter")
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope; if specified, should be "
|
||||
"a space-separated list of scope strings",
|
||||
description="Optional scope; if specified, should be " "a space-separated list of scope strings",
|
||||
)
|
||||
|
||||
|
||||
@@ -57,9 +45,7 @@ class AuthorizationErrorResponse(BaseModel):
|
||||
state: str | None = None
|
||||
|
||||
|
||||
def best_effort_extract_string(
|
||||
key: str, params: None | FormData | QueryParams
|
||||
) -> str | None:
|
||||
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
|
||||
if params is None:
|
||||
return None
|
||||
value = params.get(key)
|
||||
@@ -138,9 +124,7 @@ class AuthorizationHandler:
|
||||
|
||||
if redirect_uri and client:
|
||||
return RedirectResponse(
|
||||
url=construct_redirect_uri(
|
||||
str(redirect_uri), **error_resp.model_dump(exclude_none=True)
|
||||
),
|
||||
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
@@ -172,9 +156,7 @@ class AuthorizationHandler:
|
||||
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
|
||||
error = "unsupported_response_type"
|
||||
break
|
||||
return await error_response(
|
||||
error, stringify_pydantic_error(validation_error)
|
||||
)
|
||||
return await error_response(error, stringify_pydantic_error(validation_error))
|
||||
|
||||
# Get client information
|
||||
client = await self.provider.get_client(
|
||||
@@ -229,16 +211,9 @@ class AuthorizationHandler:
|
||||
)
|
||||
except AuthorizeError as e:
|
||||
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
|
||||
return await error_response(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
return await error_response(error=e.error, error_description=e.error_description)
|
||||
|
||||
except Exception as validation_error:
|
||||
# Catch-all for unexpected errors
|
||||
logger.exception(
|
||||
"Unexpected error in authorization_handler", exc_info=validation_error
|
||||
)
|
||||
return await error_response(
|
||||
error="server_error", error_description="An unexpected error occurred"
|
||||
)
|
||||
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
|
||||
return await error_response(error="server_error", error_description="An unexpected error occurred")
|
||||
|
||||
@@ -10,11 +10,7 @@ from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
OAuthAuthorizationServerProvider,
|
||||
RegistrationError,
|
||||
RegistrationErrorCode,
|
||||
)
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
@@ -60,9 +56,7 @@ class RegistrationHandler:
|
||||
|
||||
if client_metadata.scope is None and self.options.default_scopes is not None:
|
||||
client_metadata.scope = " ".join(self.options.default_scopes)
|
||||
elif (
|
||||
client_metadata.scope is not None and self.options.valid_scopes is not None
|
||||
):
|
||||
elif client_metadata.scope is not None and self.options.valid_scopes is not None:
|
||||
requested_scopes = set(client_metadata.scope.split())
|
||||
valid_scopes = set(self.options.valid_scopes)
|
||||
if not requested_scopes.issubset(valid_scopes):
|
||||
@@ -78,8 +72,7 @@ class RegistrationHandler:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="grant_types must be authorization_code "
|
||||
"and refresh_token",
|
||||
error_description="grant_types must be authorization_code " "and refresh_token",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
@@ -122,8 +115,6 @@ class RegistrationHandler:
|
||||
except RegistrationError as e:
|
||||
# Handle registration errors as defined in RFC 7591 Section 3.2.2
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error=e.error, error_description=e.error_description
|
||||
),
|
||||
content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
@@ -10,15 +10,8 @@ from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import (
|
||||
AuthenticationError,
|
||||
ClientAuthenticator,
|
||||
)
|
||||
from mcp.server.auth.provider import (
|
||||
AccessToken,
|
||||
OAuthAuthorizationServerProvider,
|
||||
RefreshToken,
|
||||
)
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
|
||||
|
||||
|
||||
class RevocationRequest(BaseModel):
|
||||
|
||||
@@ -7,19 +7,10 @@ from typing import Annotated, Any, Literal
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import (
|
||||
AuthenticationError,
|
||||
ClientAuthenticator,
|
||||
)
|
||||
from mcp.server.auth.provider import (
|
||||
OAuthAuthorizationServerProvider,
|
||||
TokenError,
|
||||
TokenErrorCode,
|
||||
)
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
|
||||
@@ -27,9 +18,7 @@ class AuthorizationCodeRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(..., description="The authorization code")
|
||||
redirect_uri: AnyUrl | None = Field(
|
||||
None, description="Must be the same as redirect URI provided in /authorize"
|
||||
)
|
||||
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
@@ -127,8 +116,7 @@ class TokenHandler:
|
||||
TokenErrorResponse(
|
||||
error="unsupported_grant_type",
|
||||
error_description=(
|
||||
f"Unsupported grant type (supported grant types are "
|
||||
f"{client_info.grant_types})"
|
||||
f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})"
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -137,9 +125,7 @@ class TokenHandler:
|
||||
|
||||
match token_request:
|
||||
case AuthorizationCodeRequest():
|
||||
auth_code = await self.provider.load_authorization_code(
|
||||
client_info, token_request.code
|
||||
)
|
||||
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
||||
if auth_code is None or auth_code.client_id != token_request.client_id:
|
||||
# if code belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
@@ -169,18 +155,13 @@ class TokenHandler:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=(
|
||||
"redirect_uri did not match the one "
|
||||
"used when creating auth code"
|
||||
),
|
||||
error_description=("redirect_uri did not match the one " "used when creating auth code"),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify PKCE code verifier
|
||||
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
||||
hashed_code_verifier = (
|
||||
base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
)
|
||||
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
if hashed_code_verifier != auth_code.code_challenge:
|
||||
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
||||
@@ -193,9 +174,7 @@ class TokenHandler:
|
||||
|
||||
try:
|
||||
# Exchange authorization code for tokens
|
||||
tokens = await self.provider.exchange_authorization_code(
|
||||
client_info, auth_code
|
||||
)
|
||||
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
@@ -205,13 +184,8 @@ class TokenHandler:
|
||||
)
|
||||
|
||||
case RefreshTokenRequest():
|
||||
refresh_token = await self.provider.load_refresh_token(
|
||||
client_info, token_request.refresh_token
|
||||
)
|
||||
if (
|
||||
refresh_token is None
|
||||
or refresh_token.client_id != token_request.client_id
|
||||
):
|
||||
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
||||
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
||||
# if token belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
@@ -230,29 +204,20 @@ class TokenHandler:
|
||||
)
|
||||
|
||||
# Parse scopes if provided
|
||||
scopes = (
|
||||
token_request.scope.split(" ")
|
||||
if token_request.scope
|
||||
else refresh_token.scopes
|
||||
)
|
||||
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
||||
|
||||
for scope in scopes:
|
||||
if scope not in refresh_token.scopes:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_scope",
|
||||
error_description=(
|
||||
f"cannot request scope `{scope}` "
|
||||
"not provided by refresh token"
|
||||
),
|
||||
error_description=(f"cannot request scope `{scope}` " "not provided by refresh token"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange refresh token for new tokens
|
||||
tokens = await self.provider.exchange_refresh_token(
|
||||
client_info, refresh_token, scopes
|
||||
)
|
||||
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
|
||||
@@ -7,9 +7,7 @@ from mcp.server.auth.provider import AccessToken
|
||||
|
||||
# Create a contextvar to store the authenticated user
|
||||
# The default is None, indicating no authenticated user is present
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None](
|
||||
"auth_context", default=None
|
||||
)
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
|
||||
|
||||
|
||||
def get_access_token() -> AccessToken | None:
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
AuthenticationBackend,
|
||||
SimpleUser,
|
||||
)
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Receive, Scope, Send
|
||||
@@ -35,11 +31,7 @@ class BearerAuthBackend(AuthenticationBackend):
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = next(
|
||||
(
|
||||
conn.headers.get(key)
|
||||
for key in conn.headers
|
||||
if key.lower() == "authorization"
|
||||
),
|
||||
(conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
@@ -87,10 +79,7 @@ class RequireAuthMiddleware:
|
||||
|
||||
for required_scope in self.required_scopes:
|
||||
# auth_credentials should always be provided; this is just paranoia
|
||||
if (
|
||||
auth_credentials is None
|
||||
or required_scope not in auth_credentials.scopes
|
||||
):
|
||||
if auth_credentials is None or required_scope not in auth_credentials.scopes:
|
||||
raise HTTPException(status_code=403, detail="Insufficient scope")
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
@@ -30,9 +30,7 @@ class ClientAuthenticator:
|
||||
"""
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate(
|
||||
self, client_id: str, client_secret: str | None
|
||||
) -> OAuthClientInformationFull:
|
||||
async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
|
||||
# Look up client information
|
||||
client = await self.provider.get_client(client_id)
|
||||
if not client:
|
||||
@@ -47,10 +45,7 @@ class ClientAuthenticator:
|
||||
if client.client_secret != client_secret:
|
||||
raise AuthenticationError("Invalid client_secret")
|
||||
|
||||
if (
|
||||
client.client_secret_expires_at
|
||||
and client.client_secret_expires_at < int(time.time())
|
||||
):
|
||||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
|
||||
raise AuthenticationError("Client secret has expired")
|
||||
|
||||
return client
|
||||
|
||||
@@ -4,10 +4,7 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthToken,
|
||||
)
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
|
||||
class AuthorizationParams(BaseModel):
|
||||
@@ -96,9 +93,7 @@ RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
|
||||
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
|
||||
|
||||
|
||||
class OAuthAuthorizationServerProvider(
|
||||
Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]
|
||||
):
|
||||
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""
|
||||
Retrieves client information by client ID.
|
||||
@@ -129,9 +124,7 @@ class OAuthAuthorizationServerProvider(
|
||||
"""
|
||||
...
|
||||
|
||||
async def authorize(
|
||||
self, client: OAuthClientInformationFull, params: AuthorizationParams
|
||||
) -> str:
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
"""
|
||||
Called as part of the /authorize endpoint, and returns a URL that the client
|
||||
will be redirected to.
|
||||
@@ -207,9 +200,7 @@ class OAuthAuthorizationServerProvider(
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_refresh_token(
|
||||
self, client: OAuthClientInformationFull, refresh_token: str
|
||||
) -> RefreshTokenT | None:
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
|
||||
"""
|
||||
Loads a RefreshToken by its token string.
|
||||
|
||||
|
||||
@@ -31,11 +31,7 @@ def validate_issuer_url(url: AnyHttpUrl):
|
||||
"""
|
||||
|
||||
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
|
||||
if (
|
||||
url.scheme != "https"
|
||||
and url.host != "localhost"
|
||||
and not url.host.startswith("127.0.0.1")
|
||||
):
|
||||
if url.scheme != "https" and url.host != "localhost" and not url.host.startswith("127.0.0.1"):
|
||||
raise ValueError("Issuer URL must be HTTPS")
|
||||
|
||||
# No fragments or query parameters allowed
|
||||
@@ -73,9 +69,7 @@ def create_auth_routes(
|
||||
) -> list[Route]:
|
||||
validate_issuer_url(issuer_url)
|
||||
|
||||
client_registration_options = (
|
||||
client_registration_options or ClientRegistrationOptions()
|
||||
)
|
||||
client_registration_options = client_registration_options or ClientRegistrationOptions()
|
||||
revocation_options = revocation_options or RevocationOptions()
|
||||
metadata = build_metadata(
|
||||
issuer_url,
|
||||
@@ -177,15 +171,11 @@ def build_metadata(
|
||||
|
||||
# Add registration endpoint if supported
|
||||
if client_registration_options.enabled:
|
||||
metadata.registration_endpoint = AnyHttpUrl(
|
||||
str(issuer_url).rstrip("/") + REGISTRATION_PATH
|
||||
)
|
||||
metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
|
||||
|
||||
# Add revocation endpoint if supported
|
||||
if revocation_options.enabled:
|
||||
metadata.revocation_endpoint = AnyHttpUrl(
|
||||
str(issuer_url).rstrip("/") + REVOCATION_PATH
|
||||
)
|
||||
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
|
||||
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -15,8 +15,7 @@ class RevocationOptions(BaseModel):
|
||||
class AuthSettings(BaseModel):
|
||||
issuer_url: AnyHttpUrl = Field(
|
||||
...,
|
||||
description="URL advertised as OAuth issuer; this should be the URL the server "
|
||||
"is reachable at",
|
||||
description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at",
|
||||
)
|
||||
service_documentation_url: AnyHttpUrl | None = None
|
||||
client_registration_options: ClientRegistrationOptions | None = None
|
||||
|
||||
@@ -42,13 +42,9 @@ class AssistantMessage(Message):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](
|
||||
UserMessage | AssistantMessage
|
||||
)
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)
|
||||
|
||||
SyncPromptResult = (
|
||||
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
)
|
||||
SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
|
||||
|
||||
|
||||
@@ -56,24 +52,16 @@ class PromptArgument(BaseModel):
|
||||
"""An argument that can be passed to a prompt."""
|
||||
|
||||
name: str = Field(description="Name of the argument")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the argument does"
|
||||
)
|
||||
required: bool = Field(
|
||||
default=False, description="Whether the argument is required"
|
||||
)
|
||||
description: str | None = Field(None, description="Description of what the argument does")
|
||||
required: bool = Field(default=False, description="Whether the argument is required")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt template that can be rendered with parameters."""
|
||||
|
||||
name: str = Field(description="Name of the prompt")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the prompt does"
|
||||
)
|
||||
arguments: list[PromptArgument] | None = Field(
|
||||
None, description="Arguments that can be passed to the prompt"
|
||||
)
|
||||
description: str | None = Field(None, description="Description of what the prompt does")
|
||||
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
|
||||
|
||||
@classmethod
|
||||
@@ -154,14 +142,10 @@ class Prompt(BaseModel):
|
||||
content = TextContent(type="text", text=msg)
|
||||
messages.append(UserMessage(content=content))
|
||||
else:
|
||||
content = pydantic_core.to_json(
|
||||
msg, fallback=str, indent=2
|
||||
).decode()
|
||||
content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
|
||||
messages.append(Message(role="user", content=content))
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not convert prompt result to message: {msg}"
|
||||
)
|
||||
raise ValueError(f"Could not convert prompt result to message: {msg}")
|
||||
|
||||
return messages
|
||||
except Exception as e:
|
||||
|
||||
@@ -39,9 +39,7 @@ class PromptManager:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def render_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> list[Message]:
|
||||
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
|
||||
"""Render a prompt by name with arguments."""
|
||||
prompt = self.get_prompt(name)
|
||||
if not prompt:
|
||||
|
||||
@@ -19,13 +19,9 @@ class Resource(BaseModel, abc.ABC):
|
||||
|
||||
model_config = ConfigDict(validate_default=True)
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
|
||||
default=..., description="URI of the resource"
|
||||
)
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
|
||||
name: str | None = Field(description="Name of the resource", default=None)
|
||||
description: str | None = Field(
|
||||
description="Description of the resource", default=None
|
||||
)
|
||||
description: str | None = Field(description="Description of the resource", default=None)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
|
||||
@@ -15,18 +15,12 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource
|
||||
class ResourceTemplate(BaseModel):
|
||||
"""A template for dynamically creating resources."""
|
||||
|
||||
uri_template: str = Field(
|
||||
description="URI template with parameters (e.g. weather://{city}/current)"
|
||||
)
|
||||
uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
|
||||
name: str = Field(description="Name of the resource")
|
||||
description: str | None = Field(description="Description of what the resource does")
|
||||
mime_type: str = Field(
|
||||
default="text/plain", description="MIME type of the resource content"
|
||||
)
|
||||
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
|
||||
fn: Callable[..., Any] = Field(exclude=True)
|
||||
parameters: dict[str, Any] = Field(
|
||||
description="JSON schema for function parameters"
|
||||
)
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
|
||||
@@ -54,9 +54,7 @@ class FunctionResource(Resource):
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource by calling the wrapped function."""
|
||||
try:
|
||||
result = (
|
||||
await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
|
||||
)
|
||||
result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
|
||||
if isinstance(result, Resource):
|
||||
return await result.read()
|
||||
elif isinstance(result, bytes):
|
||||
@@ -141,9 +139,7 @@ class HttpResource(Resource):
|
||||
"""A resource that reads from an HTTP endpoint."""
|
||||
|
||||
url: str = Field(description="URL to fetch content from")
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the HTTP content."""
|
||||
@@ -157,15 +153,9 @@ class DirectoryResource(Resource):
|
||||
"""A resource that lists files in a directory."""
|
||||
|
||||
path: Path = Field(description="Path to the directory")
|
||||
recursive: bool = Field(
|
||||
default=False, description="Whether to list files recursively"
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
default=None, description="Optional glob pattern to filter files"
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
recursive: bool = Field(default=False, description="Whether to list files recursively")
|
||||
pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
@@ -184,16 +174,8 @@ class DirectoryResource(Resource):
|
||||
|
||||
try:
|
||||
if self.pattern:
|
||||
return (
|
||||
list(self.path.glob(self.pattern))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob(self.pattern))
|
||||
)
|
||||
return (
|
||||
list(self.path.glob("*"))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob("*"))
|
||||
)
|
||||
return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern))
|
||||
return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*"))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error listing directory {self.path}: {e}")
|
||||
|
||||
|
||||
@@ -97,9 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
|
||||
# StreamableHTTP settings
|
||||
json_response: bool = False
|
||||
stateless_http: bool = (
|
||||
False # If True, uses true stateless mode (new transport per request)
|
||||
)
|
||||
stateless_http: bool = False # If True, uses true stateless mode (new transport per request)
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
@@ -115,9 +113,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
|
||||
None, description="Lifespan context manager"
|
||||
)
|
||||
|
||||
auth: AuthSettings | None = None
|
||||
|
||||
@@ -125,9 +123,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[
|
||||
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
|
||||
]:
|
||||
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
@@ -141,8 +137,7 @@ class FastMCP:
|
||||
self,
|
||||
name: str | None = None,
|
||||
instructions: str | None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
| None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None,
|
||||
event_store: EventStore | None = None,
|
||||
*,
|
||||
tools: list[Tool] | None = None,
|
||||
@@ -153,31 +148,18 @@ class FastMCP:
|
||||
self._mcp_server = MCPServer(
|
||||
name=name or "FastMCP",
|
||||
instructions=instructions,
|
||||
lifespan=(
|
||||
lifespan_wrapper(self, self.settings.lifespan)
|
||||
if self.settings.lifespan
|
||||
else default_lifespan
|
||||
),
|
||||
)
|
||||
self._tool_manager = ToolManager(
|
||||
tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
self._resource_manager = ResourceManager(
|
||||
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
|
||||
)
|
||||
self._prompt_manager = PromptManager(
|
||||
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
|
||||
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
|
||||
)
|
||||
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
|
||||
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
|
||||
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
|
||||
if (self.settings.auth is not None) != (auth_server_provider is not None):
|
||||
# TODO: after we support separate authorization servers (see
|
||||
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
|
||||
# we should validate that if auth is enabled, we have either an
|
||||
# auth_server_provider to host our own authorization server,
|
||||
# OR the URL of a 3rd party authorization server.
|
||||
raise ValueError(
|
||||
"settings.auth must be specified if and only if auth_server_provider "
|
||||
"is specified"
|
||||
)
|
||||
raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified")
|
||||
self._auth_server_provider = auth_server_provider
|
||||
self._event_store = event_store
|
||||
self._custom_starlette_routes: list[Route] = []
|
||||
@@ -340,9 +322,7 @@ class FastMCP:
|
||||
description: Optional description of what the tool does
|
||||
annotations: Optional ToolAnnotations providing additional tool information
|
||||
"""
|
||||
self._tool_manager.add_tool(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
self._tool_manager.add_tool(fn, name=name, description=description, annotations=annotations)
|
||||
|
||||
def tool(
|
||||
self,
|
||||
@@ -379,14 +359,11 @@ class FastMCP:
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @tool decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @tool() instead of @tool"
|
||||
"The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool"
|
||||
)
|
||||
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
self.add_tool(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
self.add_tool(fn, name=name, description=description, annotations=annotations)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
@@ -462,8 +439,7 @@ class FastMCP:
|
||||
|
||||
if uri_params != func_params:
|
||||
raise ValueError(
|
||||
f"Mismatch between URI parameters {uri_params} "
|
||||
f"and function parameters {func_params}"
|
||||
f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}"
|
||||
)
|
||||
|
||||
# Register as template
|
||||
@@ -496,9 +472,7 @@ class FastMCP:
|
||||
"""
|
||||
self._prompt_manager.add_prompt(prompt)
|
||||
|
||||
def prompt(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
def prompt(self, name: str | None = None, description: str | None = None) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a prompt.
|
||||
|
||||
Args:
|
||||
@@ -665,9 +639,7 @@ class FastMCP:
|
||||
self.settings.mount_path = mount_path
|
||||
|
||||
# Create normalized endpoint considering the mount path
|
||||
normalized_message_endpoint = self._normalize_path(
|
||||
self.settings.mount_path, self.settings.message_path
|
||||
)
|
||||
normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path)
|
||||
|
||||
# Set up auth context and dependencies
|
||||
|
||||
@@ -764,9 +736,7 @@ class FastMCP:
|
||||
routes.extend(self._custom_starlette_routes)
|
||||
|
||||
# Create Starlette app with routes and middleware
|
||||
return Starlette(
|
||||
debug=self.settings.debug, routes=routes, middleware=middleware
|
||||
)
|
||||
return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)
|
||||
|
||||
def streamable_http_app(self) -> Starlette:
|
||||
"""Return an instance of the StreamableHTTP server app."""
|
||||
@@ -783,9 +753,7 @@ class FastMCP:
|
||||
)
|
||||
|
||||
# Create the ASGI handler
|
||||
async def handle_streamable_http(
|
||||
scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.session_manager.handle_request(scope, receive, send)
|
||||
|
||||
# Create routes
|
||||
@@ -861,9 +829,7 @@ class FastMCP:
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> GetPromptResult:
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
|
||||
"""Get a prompt by name with arguments."""
|
||||
try:
|
||||
messages = await self._prompt_manager.render_prompt(name, arguments)
|
||||
@@ -936,9 +902,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_context: (
|
||||
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||
) = None,
|
||||
request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None,
|
||||
fastmcp: FastMCP | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -962,9 +926,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None, message: str | None = None
|
||||
) -> None:
|
||||
async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
@@ -972,11 +934,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
total: Optional total value e.g. 100
|
||||
message: Optional message e.g. Starting render...
|
||||
"""
|
||||
progress_token = (
|
||||
self.request_context.meta.progressToken
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
progress_token = self.request_context.meta.progressToken if self.request_context.meta else None
|
||||
|
||||
if progress_token is None:
|
||||
return
|
||||
@@ -997,9 +955,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
Returns:
|
||||
The resource content as either text or bytes
|
||||
"""
|
||||
assert (
|
||||
self._fastmcp is not None
|
||||
), "Context is not available outside of a request"
|
||||
assert self._fastmcp is not None, "Context is not available outside of a request"
|
||||
return await self._fastmcp.read_resource(uri)
|
||||
|
||||
async def log(
|
||||
@@ -1027,11 +983,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
"""Get the client ID if available."""
|
||||
return (
|
||||
getattr(self.request_context.meta, "client_id", None)
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
|
||||
@@ -25,16 +25,11 @@ class Tool(BaseModel):
|
||||
description: str = Field(description="Description of what the tool does")
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
|
||||
fn_metadata: FuncMetadata = Field(
|
||||
description="Metadata about the function including a pydantic model for tool"
|
||||
" arguments"
|
||||
description="Metadata about the function including a pydantic model for tool" " arguments"
|
||||
)
|
||||
is_async: bool = Field(description="Whether the tool is async")
|
||||
context_kwarg: str | None = Field(
|
||||
None, description="Name of the kwarg that should receive context"
|
||||
)
|
||||
annotations: ToolAnnotations | None = Field(
|
||||
None, description="Optional annotations for the tool"
|
||||
)
|
||||
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
|
||||
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@@ -93,9 +88,7 @@ class Tool(BaseModel):
|
||||
self.fn,
|
||||
self.is_async,
|
||||
arguments,
|
||||
{self.context_kwarg: context}
|
||||
if self.context_kwarg is not None
|
||||
else None,
|
||||
{self.context_kwarg: context} if self.context_kwarg is not None else None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||
|
||||
@@ -50,9 +50,7 @@ class ToolManager:
|
||||
annotations: ToolAnnotations | None = None,
|
||||
) -> Tool:
|
||||
"""Add a tool to the server."""
|
||||
tool = Tool.from_function(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
tool = Tool.from_function(fn, name=name, description=description, annotations=annotations)
|
||||
existing = self._tools.get(tool.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_tools:
|
||||
|
||||
@@ -102,9 +102,7 @@ class FuncMetadata(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def func_metadata(
|
||||
func: Callable[..., Any], skip_names: Sequence[str] = ()
|
||||
) -> FuncMetadata:
|
||||
def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata:
|
||||
"""Given a function, return metadata including a pydantic model representing its
|
||||
signature.
|
||||
|
||||
@@ -131,9 +129,7 @@ def func_metadata(
|
||||
globalns = getattr(func, "__globals__", {})
|
||||
for param in params.values():
|
||||
if param.name.startswith("_"):
|
||||
raise InvalidSignature(
|
||||
f"Parameter {param.name} of {func.__name__} cannot start with '_'"
|
||||
)
|
||||
raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
|
||||
if param.name in skip_names:
|
||||
continue
|
||||
annotation = param.annotation
|
||||
@@ -142,11 +138,7 @@ def func_metadata(
|
||||
if annotation is None:
|
||||
annotation = Annotated[
|
||||
None,
|
||||
Field(
|
||||
default=param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined
|
||||
),
|
||||
Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
|
||||
]
|
||||
|
||||
# Untyped field
|
||||
@@ -160,9 +152,7 @@ def func_metadata(
|
||||
|
||||
field_info = FieldInfo.from_annotated_attribute(
|
||||
_get_typed_annotation(annotation, globalns),
|
||||
param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined,
|
||||
param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
|
||||
)
|
||||
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
|
||||
continue
|
||||
@@ -177,9 +167,7 @@ def func_metadata(
|
||||
|
||||
|
||||
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
def try_eval_type(
|
||||
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
|
||||
) -> tuple[Any, bool]:
|
||||
def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]:
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
|
||||
@@ -95,9 +95,7 @@ LifespanResultT = TypeVar("LifespanResultT")
|
||||
RequestT = TypeVar("RequestT", default=Any)
|
||||
|
||||
# This will be properly typed in each Server instance's context
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
@@ -140,9 +138,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
self.version = version
|
||||
self.instructions = instructions
|
||||
self.lifespan = lifespan
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
|
||||
types.PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
@@ -189,9 +185,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if types.ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = types.PromptsCapability(
|
||||
listChanged=notification_options.prompts_changed
|
||||
)
|
||||
prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if types.ListResourcesRequest in self.request_handlers:
|
||||
@@ -201,9 +195,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
tools_capability = types.ToolsCapability(
|
||||
listChanged=notification_options.tools_changed
|
||||
)
|
||||
tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if types.SetLevelRequest in self.request_handlers:
|
||||
@@ -239,9 +231,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
def get_prompt(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
|
||||
],
|
||||
func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
@@ -260,9 +250,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourcesResult(resources=resources)
|
||||
)
|
||||
return types.ServerResult(types.ListResourcesResult(resources=resources))
|
||||
|
||||
self.request_handlers[types.ListResourcesRequest] = handler
|
||||
return func
|
||||
@@ -275,9 +263,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
async def handler(_: Any):
|
||||
templates = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourceTemplatesResult(resourceTemplates=templates)
|
||||
)
|
||||
return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))
|
||||
|
||||
self.request_handlers[types.ListResourceTemplatesRequest] = handler
|
||||
return func
|
||||
@@ -286,9 +272,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
|
||||
],
|
||||
func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]],
|
||||
):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
@@ -323,8 +307,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
content = create_content(data, None)
|
||||
case Iterable() as contents:
|
||||
contents_list = [
|
||||
create_content(content_item.content, content_item.mime_type)
|
||||
for content_item in contents
|
||||
create_content(content_item.content, content_item.mime_type) for content_item in contents
|
||||
]
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
@@ -332,9 +315,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
)
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unexpected return type from read_resource: {type(result)}"
|
||||
)
|
||||
raise ValueError(f"Unexpected return type from read_resource: {type(result)}")
|
||||
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
@@ -404,12 +385,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Iterable[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]
|
||||
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
|
||||
],
|
||||
],
|
||||
):
|
||||
@@ -418,9 +394,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
async def handler(req: types.CallToolRequest):
|
||||
try:
|
||||
results = await func(req.params.name, (req.params.arguments or {}))
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(content=list(results), isError=False)
|
||||
)
|
||||
return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
|
||||
except Exception as e:
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
@@ -436,9 +410,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str | int, float, float | None, str | None], Awaitable[None]
|
||||
],
|
||||
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
@@ -525,9 +497,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception,
|
||||
message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
|
||||
session: ServerSession,
|
||||
lifespan_context: LifespanResultT,
|
||||
raise_exceptions: bool = False,
|
||||
@@ -535,20 +505,14 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
# TODO(Marcelo): We should be checking if message is Exception here.
|
||||
match message: # type: ignore[reportMatchNotExhaustive]
|
||||
case (
|
||||
RequestResponder(request=types.ClientRequest(root=req)) as responder
|
||||
):
|
||||
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
|
||||
with responder:
|
||||
await self._handle_request(
|
||||
message, req, session, lifespan_context, raise_exceptions
|
||||
)
|
||||
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
|
||||
case types.ClientNotification(root=notify):
|
||||
await self._handle_notification(notify)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
"Warning: %s: %s", warning.category.__name__, warning.message
|
||||
)
|
||||
logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
|
||||
|
||||
async def _handle_request(
|
||||
self,
|
||||
@@ -566,9 +530,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
try:
|
||||
# Extract request context from message metadata
|
||||
request_data = None
|
||||
if message.message_metadata is not None and isinstance(
|
||||
message.message_metadata, ServerMessageMetadata
|
||||
):
|
||||
if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
|
||||
request_data = message.message_metadata.request_context
|
||||
|
||||
# Set our global state that can be retrieved via
|
||||
|
||||
@@ -64,9 +64,7 @@ class InitializationState(Enum):
|
||||
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
|
||||
|
||||
ServerRequestResponder = (
|
||||
RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
| types.ClientNotification
|
||||
| Exception
|
||||
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
|
||||
)
|
||||
|
||||
|
||||
@@ -89,22 +87,16 @@ class ServerSession(
|
||||
init_options: InitializationOptions,
|
||||
stateless: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream, write_stream, types.ClientRequest, types.ClientNotification
|
||||
)
|
||||
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
|
||||
self._initialization_state = (
|
||||
InitializationState.Initialized
|
||||
if stateless
|
||||
else InitializationState.NotInitialized
|
||||
InitializationState.Initialized if stateless else InitializationState.NotInitialized
|
||||
)
|
||||
|
||||
self._init_options = init_options
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
|
||||
anyio.create_memory_object_stream[ServerRequestResponder](0)
|
||||
)
|
||||
self._exit_stack.push_async_callback(
|
||||
lambda: self._incoming_message_stream_reader.aclose()
|
||||
)
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
|
||||
ServerRequestResponder
|
||||
](0)
|
||||
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
|
||||
|
||||
@property
|
||||
def client_params(self) -> types.InitializeRequestParams | None:
|
||||
@@ -134,10 +126,7 @@ class ServerSession(
|
||||
return False
|
||||
# Check each experimental capability
|
||||
for exp_key, exp_value in capability.experimental.items():
|
||||
if (
|
||||
exp_key not in client_caps.experimental
|
||||
or client_caps.experimental[exp_key] != exp_value
|
||||
):
|
||||
if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -146,9 +135,7 @@ class ServerSession(
|
||||
async with self._incoming_message_stream_writer:
|
||||
await super()._receive_loop()
|
||||
|
||||
async def _received_request(
|
||||
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
|
||||
):
|
||||
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
|
||||
match responder.request.root:
|
||||
case types.InitializeRequest(params=params):
|
||||
requested_version = params.protocolVersion
|
||||
@@ -172,13 +159,9 @@ class ServerSession(
|
||||
)
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received request before initialization was complete"
|
||||
)
|
||||
raise RuntimeError("Received request before initialization was complete")
|
||||
|
||||
async def _received_notification(
|
||||
self, notification: types.ClientNotification
|
||||
) -> None:
|
||||
async def _received_notification(self, notification: types.ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
@@ -186,9 +169,7 @@ class ServerSession(
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError(
|
||||
"Received notification before initialization was complete"
|
||||
)
|
||||
raise RuntimeError("Received notification before initialization was complete")
|
||||
|
||||
async def send_log_message(
|
||||
self,
|
||||
|
||||
@@ -116,20 +116,14 @@ class SseServerTransport:
|
||||
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
|
||||
|
||||
# This is the URI (path + query) the client will use to POST messages.
|
||||
client_post_uri_data = (
|
||||
f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
|
||||
)
|
||||
client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
|
||||
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
](0)
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0)
|
||||
|
||||
async def sse_writer():
|
||||
logger.debug("Starting SSE writer")
|
||||
async with sse_stream_writer, write_stream_reader:
|
||||
await sse_stream_writer.send(
|
||||
{"event": "endpoint", "data": client_post_uri_data}
|
||||
)
|
||||
await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data})
|
||||
logger.debug(f"Sent endpoint event: {client_post_uri_data}")
|
||||
|
||||
async for session_message in write_stream_reader:
|
||||
@@ -137,9 +131,7 @@ class SseServerTransport:
|
||||
await sse_stream_writer.send(
|
||||
{
|
||||
"event": "message",
|
||||
"data": session_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
"data": session_message.message.model_dump_json(by_alias=True, exclude_none=True),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -151,9 +143,9 @@ class SseServerTransport:
|
||||
In this case we close our side of the streams to signal the client that
|
||||
the connection has been closed.
|
||||
"""
|
||||
await EventSourceResponse(
|
||||
content=sse_stream_reader, data_sender_callable=sse_writer
|
||||
)(scope, receive, send)
|
||||
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
|
||||
scope, receive, send
|
||||
)
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
logging.debug(f"Client session disconnected {session_id}")
|
||||
@@ -164,9 +156,7 @@ class SseServerTransport:
|
||||
logger.debug("Yielding read and write streams")
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
async def handle_post_message(
|
||||
self, scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
logger.debug("Handling POST message")
|
||||
request = Request(scope, receive)
|
||||
|
||||
|
||||
@@ -76,9 +76,7 @@ async def stdio_server(
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
json = session_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await stdout.write(json + "\n")
|
||||
await stdout.flush()
|
||||
except anyio.ClosedResourceError:
|
||||
|
||||
@@ -82,9 +82,7 @@ class EventStore(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def store_event(
|
||||
self, stream_id: StreamId, message: JSONRPCMessage
|
||||
) -> EventId:
|
||||
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
|
||||
"""
|
||||
Stores an event for later retrieval.
|
||||
|
||||
@@ -125,9 +123,7 @@ class StreamableHTTPServerTransport:
|
||||
"""
|
||||
|
||||
# Server notification streams for POST requests as well as standalone SSE stream
|
||||
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
|
||||
None
|
||||
)
|
||||
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None
|
||||
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
|
||||
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
|
||||
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
|
||||
@@ -153,12 +149,8 @@ class StreamableHTTPServerTransport:
|
||||
Raises:
|
||||
ValueError: If the session ID contains invalid characters.
|
||||
"""
|
||||
if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(
|
||||
mcp_session_id
|
||||
):
|
||||
raise ValueError(
|
||||
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
|
||||
)
|
||||
if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id):
|
||||
raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)")
|
||||
|
||||
self.mcp_session_id = mcp_session_id
|
||||
self.is_json_response_enabled = is_json_response_enabled
|
||||
@@ -218,9 +210,7 @@ class StreamableHTTPServerTransport:
|
||||
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
return Response(
|
||||
response_message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
if response_message
|
||||
else None,
|
||||
response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None,
|
||||
status_code=status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
@@ -233,9 +223,7 @@ class StreamableHTTPServerTransport:
|
||||
"""Create event data dictionary from an EventMessage."""
|
||||
event_data = {
|
||||
"event": "message",
|
||||
"data": event_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
),
|
||||
"data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
|
||||
}
|
||||
|
||||
# If an event ID was provided, include it
|
||||
@@ -283,42 +271,29 @@ class StreamableHTTPServerTransport:
|
||||
accept_header = request.headers.get("accept", "")
|
||||
accept_types = [media_type.strip() for media_type in accept_header.split(",")]
|
||||
|
||||
has_json = any(
|
||||
media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types
|
||||
)
|
||||
has_sse = any(
|
||||
media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
|
||||
)
|
||||
has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
|
||||
has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)
|
||||
|
||||
return has_json, has_sse
|
||||
|
||||
def _check_content_type(self, request: Request) -> bool:
|
||||
"""Check if the request has the correct Content-Type."""
|
||||
content_type = request.headers.get("content-type", "")
|
||||
content_type_parts = [
|
||||
part.strip() for part in content_type.split(";")[0].split(",")
|
||||
]
|
||||
content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")]
|
||||
|
||||
return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
|
||||
|
||||
async def _handle_post_request(
|
||||
self, scope: Scope, request: Request, receive: Receive, send: Send
|
||||
) -> None:
|
||||
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
|
||||
"""Handle POST requests containing JSON-RPC messages."""
|
||||
writer = self._read_stream_writer
|
||||
if writer is None:
|
||||
raise ValueError(
|
||||
"No read stream writer available. Ensure connect() is called first."
|
||||
)
|
||||
raise ValueError("No read stream writer available. Ensure connect() is called first.")
|
||||
try:
|
||||
# Check Accept headers
|
||||
has_json, has_sse = self._check_accept_headers(request)
|
||||
if not (has_json and has_sse):
|
||||
response = self._create_error_response(
|
||||
(
|
||||
"Not Acceptable: Client must accept both application/json and "
|
||||
"text/event-stream"
|
||||
),
|
||||
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
|
||||
HTTPStatus.NOT_ACCEPTABLE,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
@@ -346,9 +321,7 @@ class StreamableHTTPServerTransport:
|
||||
try:
|
||||
raw_message = json.loads(body)
|
||||
except json.JSONDecodeError as e:
|
||||
response = self._create_error_response(
|
||||
f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
|
||||
)
|
||||
response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
@@ -364,10 +337,7 @@ class StreamableHTTPServerTransport:
|
||||
return
|
||||
|
||||
# Check if this is an initialization request
|
||||
is_initialization_request = (
|
||||
isinstance(message.root, JSONRPCRequest)
|
||||
and message.root.method == "initialize"
|
||||
)
|
||||
is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
if is_initialization_request:
|
||||
# Check if the server already has an established session
|
||||
@@ -406,9 +376,7 @@ class StreamableHTTPServerTransport:
|
||||
# Extract the request ID outside the try block for proper scope
|
||||
request_id = str(message.root.id)
|
||||
# Register this stream for the request ID
|
||||
self._request_streams[request_id] = anyio.create_memory_object_stream[
|
||||
EventMessage
|
||||
](0)
|
||||
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
|
||||
request_stream_reader = self._request_streams[request_id][1]
|
||||
|
||||
if self.is_json_response_enabled:
|
||||
@@ -424,16 +392,12 @@ class StreamableHTTPServerTransport:
|
||||
# Use similar approach to SSE writer for consistency
|
||||
async for event_message in request_stream_reader:
|
||||
# If it's a response, this is what we're waiting for
|
||||
if isinstance(
|
||||
event_message.message.root, JSONRPCResponse | JSONRPCError
|
||||
):
|
||||
if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
|
||||
response_message = event_message.message
|
||||
break
|
||||
# For notifications and request, keep waiting
|
||||
else:
|
||||
logger.debug(
|
||||
f"received: {event_message.message.root.method}"
|
||||
)
|
||||
logger.debug(f"received: {event_message.message.root.method}")
|
||||
|
||||
# At this point we should have a response
|
||||
if response_message:
|
||||
@@ -442,9 +406,7 @@ class StreamableHTTPServerTransport:
|
||||
await response(scope, receive, send)
|
||||
else:
|
||||
# This shouldn't happen in normal operation
|
||||
logger.error(
|
||||
"No response message received before stream closed"
|
||||
)
|
||||
logger.error("No response message received before stream closed")
|
||||
response = self._create_error_response(
|
||||
"Error processing request: No response received",
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
@@ -462,9 +424,7 @@ class StreamableHTTPServerTransport:
|
||||
await self._clean_up_memory_streams(request_id)
|
||||
else:
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = (
|
||||
anyio.create_memory_object_stream[dict[str, str]](0)
|
||||
)
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
|
||||
|
||||
async def sse_writer():
|
||||
# Get the request ID from the incoming request message
|
||||
@@ -495,11 +455,7 @@ class StreamableHTTPServerTransport:
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": CONTENT_TYPE_SSE,
|
||||
**(
|
||||
{MCP_SESSION_ID_HEADER: self.mcp_session_id}
|
||||
if self.mcp_session_id
|
||||
else {}
|
||||
),
|
||||
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
|
||||
}
|
||||
response = EventSourceResponse(
|
||||
content=sse_stream_reader,
|
||||
@@ -544,9 +500,7 @@ class StreamableHTTPServerTransport:
|
||||
"""
|
||||
writer = self._read_stream_writer
|
||||
if writer is None:
|
||||
raise ValueError(
|
||||
"No read stream writer available. Ensure connect() is called first."
|
||||
)
|
||||
raise ValueError("No read stream writer available. Ensure connect() is called first.")
|
||||
|
||||
# Validate Accept header - must include text/event-stream
|
||||
_, has_sse = self._check_accept_headers(request)
|
||||
@@ -585,17 +539,13 @@ class StreamableHTTPServerTransport:
|
||||
return
|
||||
|
||||
# Create SSE stream
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, str]
|
||||
](0)
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
|
||||
|
||||
async def standalone_sse_writer():
|
||||
try:
|
||||
# Create a standalone message stream for server-initiated messages
|
||||
|
||||
self._request_streams[GET_STREAM_KEY] = (
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
|
||||
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
|
||||
|
||||
async with sse_stream_writer, standalone_stream_reader:
|
||||
@@ -732,9 +682,7 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
return True
|
||||
|
||||
async def _replay_events(
|
||||
self, last_event_id: str, request: Request, send: Send
|
||||
) -> None:
|
||||
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
|
||||
"""
|
||||
Replays events that would have been sent after the specified event ID.
|
||||
Only used when resumability is enabled.
|
||||
@@ -754,9 +702,7 @@ class StreamableHTTPServerTransport:
|
||||
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
|
||||
|
||||
# Create SSE stream for replay
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
|
||||
dict[str, str]
|
||||
](0)
|
||||
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
|
||||
|
||||
async def replay_sender():
|
||||
try:
|
||||
@@ -767,15 +713,11 @@ class StreamableHTTPServerTransport:
|
||||
await sse_stream_writer.send(event_data)
|
||||
|
||||
# Replay past events and get the stream ID
|
||||
stream_id = await event_store.replay_events_after(
|
||||
last_event_id, send_event
|
||||
)
|
||||
stream_id = await event_store.replay_events_after(last_event_id, send_event)
|
||||
|
||||
# If stream ID not in mapping, create it
|
||||
if stream_id and stream_id not in self._request_streams:
|
||||
self._request_streams[stream_id] = (
|
||||
anyio.create_memory_object_stream[EventMessage](0)
|
||||
)
|
||||
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
|
||||
msg_reader = self._request_streams[stream_id][1]
|
||||
|
||||
# Forward messages to SSE
|
||||
@@ -829,12 +771,8 @@ class StreamableHTTPServerTransport:
|
||||
|
||||
# Create the memory streams for this connection
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[
|
||||
SessionMessage | Exception
|
||||
](0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[
|
||||
SessionMessage
|
||||
](0)
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
|
||||
|
||||
# Store the streams
|
||||
self._read_stream_writer = read_stream_writer
|
||||
@@ -867,35 +805,24 @@ class StreamableHTTPServerTransport:
|
||||
session_message.metadata,
|
||||
ServerMessageMetadata,
|
||||
)
|
||||
and session_message.metadata.related_request_id
|
||||
is not None
|
||||
and session_message.metadata.related_request_id is not None
|
||||
):
|
||||
target_request_id = str(
|
||||
session_message.metadata.related_request_id
|
||||
)
|
||||
target_request_id = str(session_message.metadata.related_request_id)
|
||||
|
||||
request_stream_id = (
|
||||
target_request_id
|
||||
if target_request_id is not None
|
||||
else GET_STREAM_KEY
|
||||
)
|
||||
request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY
|
||||
|
||||
# Store the event if we have an event store,
|
||||
# regardless of whether a client is connected
|
||||
# messages will be replayed on the re-connect
|
||||
event_id = None
|
||||
if self._event_store:
|
||||
event_id = await self._event_store.store_event(
|
||||
request_stream_id, message
|
||||
)
|
||||
event_id = await self._event_store.store_event(request_stream_id, message)
|
||||
logger.debug(f"Stored {event_id} from {request_stream_id}")
|
||||
|
||||
if request_stream_id in self._request_streams:
|
||||
try:
|
||||
# Send both the message and the event ID
|
||||
await self._request_streams[request_stream_id][0].send(
|
||||
EventMessage(message, event_id)
|
||||
)
|
||||
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
|
||||
except (
|
||||
anyio.BrokenResourceError,
|
||||
anyio.ClosedResourceError,
|
||||
|
||||
@@ -165,9 +165,7 @@ class StreamableHTTPSessionManager:
|
||||
)
|
||||
|
||||
# Start server in a new task
|
||||
async def run_stateless_server(
|
||||
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
|
||||
):
|
||||
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
task_status.started()
|
||||
@@ -204,10 +202,7 @@ class StreamableHTTPSessionManager:
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
|
||||
# Existing session case
|
||||
if (
|
||||
request_mcp_session_id is not None
|
||||
and request_mcp_session_id in self._server_instances
|
||||
):
|
||||
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
|
||||
transport = self._server_instances[request_mcp_session_id]
|
||||
logger.debug("Session already exists, handling request directly")
|
||||
await transport.handle_request(scope, receive, send)
|
||||
@@ -229,9 +224,7 @@ class StreamableHTTPSessionManager:
|
||||
logger.info(f"Created new transport with session ID: {new_session_id}")
|
||||
|
||||
# Define the server runner
|
||||
async def run_server(
|
||||
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
|
||||
) -> None:
|
||||
async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
task_status.started()
|
||||
|
||||
@@ -93,12 +93,8 @@ class StreamingASGITransport(AsyncBaseTransport):
|
||||
initial_response_ready = anyio.Event()
|
||||
|
||||
# Synchronization for streaming response
|
||||
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[
|
||||
dict[str, Any]
|
||||
](100)
|
||||
content_send_channel, content_receive_channel = (
|
||||
anyio.create_memory_object_stream[bytes](100)
|
||||
)
|
||||
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100)
|
||||
content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100)
|
||||
|
||||
# ASGI callables.
|
||||
async def receive() -> dict[str, Any]:
|
||||
@@ -124,21 +120,15 @@ class StreamingASGITransport(AsyncBaseTransport):
|
||||
async def run_app() -> None:
|
||||
try:
|
||||
# Cast the receive and send functions to the ASGI types
|
||||
await self.app(
|
||||
cast(Scope, scope), cast(Receive, receive), cast(Send, send)
|
||||
)
|
||||
await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send))
|
||||
except Exception:
|
||||
if self.raise_app_exceptions:
|
||||
raise
|
||||
|
||||
if not response_started:
|
||||
await asgi_send_channel.send(
|
||||
{"type": "http.response.start", "status": 500, "headers": []}
|
||||
)
|
||||
await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []})
|
||||
|
||||
await asgi_send_channel.send(
|
||||
{"type": "http.response.body", "body": b"", "more_body": False}
|
||||
)
|
||||
await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
finally:
|
||||
await asgi_send_channel.aclose()
|
||||
|
||||
|
||||
@@ -51,9 +51,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
obj = session_message.message.model_dump_json(
|
||||
by_alias=True, exclude_none=True
|
||||
)
|
||||
obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await websocket.send_text(obj)
|
||||
except anyio.ClosedResourceError:
|
||||
await websocket.close()
|
||||
|
||||
Reference in New Issue
Block a user