Use 120 characters instead of 88 (#856)

This commit is contained in:
Marcelo Trylesinski
2025-06-11 02:45:50 -07:00
committed by GitHub
parent f7265f7b91
commit 543961968c
90 changed files with 687 additions and 2142 deletions

View File

@@ -47,18 +47,14 @@ mcp = FastMCP(
DB_DSN = "postgresql://postgres:postgres@localhost:54320/memory_db" DB_DSN = "postgresql://postgres:postgres@localhost:54320/memory_db"
# reset memory with rm ~/.fastmcp/{USER}/memory/* # reset memory with rm ~/.fastmcp/{USER}/memory/*
PROFILE_DIR = ( PROFILE_DIR = (Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory").resolve()
Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory"
).resolve()
PROFILE_DIR.mkdir(parents=True, exist_ok=True) PROFILE_DIR.mkdir(parents=True, exist_ok=True)
def cosine_similarity(a: list[float], b: list[float]) -> float: def cosine_similarity(a: list[float], b: list[float]) -> float:
a_array = np.array(a, dtype=np.float64) a_array = np.array(a, dtype=np.float64)
b_array = np.array(b, dtype=np.float64) b_array = np.array(b, dtype=np.float64)
return np.dot(a_array, b_array) / ( return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array))
np.linalg.norm(a_array) * np.linalg.norm(b_array)
)
async def do_ai[T]( async def do_ai[T](
@@ -97,9 +93,7 @@ class MemoryNode(BaseModel):
summary: str = "" summary: str = ""
importance: float = 1.0 importance: float = 1.0
access_count: int = 0 access_count: int = 0
timestamp: float = Field( timestamp: float = Field(default_factory=lambda: datetime.now(timezone.utc).timestamp())
default_factory=lambda: datetime.now(timezone.utc).timestamp()
)
embedding: list[float] embedding: list[float]
@classmethod @classmethod
@@ -152,9 +146,7 @@ class MemoryNode(BaseModel):
self.importance += other.importance self.importance += other.importance
self.access_count += other.access_count self.access_count += other.access_count
self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)] self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)]
self.summary = await do_ai( self.summary = await do_ai(self.content, "Summarize the following text concisely.", str, deps)
self.content, "Summarize the following text concisely.", str, deps
)
await self.save(deps) await self.save(deps)
# Delete the merged node from the database # Delete the merged node from the database
if other.id is not None: if other.id is not None:
@@ -221,9 +213,7 @@ async def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo
async def update_importance(user_embedding: list[float], deps: Deps): async def update_importance(user_embedding: list[float], deps: Deps):
async with deps.pool.acquire() as conn: async with deps.pool.acquire() as conn:
rows = await conn.fetch( rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories")
"SELECT id, importance, access_count, embedding FROM memories"
)
for row in rows: for row in rows:
memory_embedding = row["embedding"] memory_embedding = row["embedding"]
similarity = cosine_similarity(user_embedding, memory_embedding) similarity = cosine_similarity(user_embedding, memory_embedding)
@@ -273,9 +263,7 @@ async def display_memory_tree(deps: Deps) -> str:
) )
result = "" result = ""
for row in rows: for row in rows:
effective_importance = row["importance"] * ( effective_importance = row["importance"] * (1 + math.log(row["access_count"] + 1))
1 + math.log(row["access_count"] + 1)
)
summary = row["summary"] or row["content"] summary = row["summary"] or row["content"]
result += f"- {summary} (Importance: {effective_importance:.2f})\n" result += f"- {summary} (Importance: {effective_importance:.2f})\n"
return result return result
@@ -283,15 +271,11 @@ async def display_memory_tree(deps: Deps) -> str:
@mcp.tool() @mcp.tool()
async def remember( async def remember(
contents: list[str] = Field( contents: list[str] = Field(description="List of observations or memories to store"),
description="List of observations or memories to store"
),
): ):
deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool()) deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool())
try: try:
return "\n".join( return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents]))
await asyncio.gather(*[add_memory(content, deps) for content in contents])
)
finally: finally:
await deps.pool.close() await deps.pool.close()
@@ -305,9 +289,7 @@ async def read_profile() -> str:
async def initialize_database(): async def initialize_database():
pool = await asyncpg.create_pool( pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres")
"postgresql://postgres:postgres@localhost:54320/postgres"
)
try: try:
async with pool.acquire() as conn: async with pool.acquire() as conn:
await conn.execute(""" await conn.execute("""

View File

@@ -28,15 +28,11 @@ from mcp.server.fastmcp import FastMCP
class SurgeSettings(BaseSettings): class SurgeSettings(BaseSettings):
model_config: SettingsConfigDict = SettingsConfigDict( model_config: SettingsConfigDict = SettingsConfigDict(env_prefix="SURGE_", env_file=".env")
env_prefix="SURGE_", env_file=".env"
)
api_key: str api_key: str
account_id: str account_id: str
my_phone_number: Annotated[ my_phone_number: Annotated[str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)]
str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)
]
my_first_name: str my_first_name: str
my_last_name: str my_last_name: str

View File

@@ -8,10 +8,7 @@ from mcp.server.fastmcp import FastMCP
mcp = FastMCP() mcp = FastMCP()
@mcp.tool( @mcp.tool(description="🌟 A tool that uses various Unicode characters in its description: " "á é í ó ú ñ 漢字 🎉")
description="🌟 A tool that uses various Unicode characters in its description: "
"á é í ó ú ñ 漢字 🎉"
)
def hello_unicode(name: str = "世界", greeting: str = "¡Hola") -> str: def hello_unicode(name: str = "世界", greeting: str = "¡Hola") -> str:
""" """
A simple tool that demonstrates Unicode handling in: A simple tool that demonstrates Unicode handling in:

View File

@@ -82,9 +82,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
"""Register a new OAuth client.""" """Register a new OAuth client."""
self.clients[client_info.client_id] = client_info self.clients[client_info.client_id] = client_info
async def authorize( async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
"""Generate an authorization URL for GitHub OAuth flow.""" """Generate an authorization URL for GitHub OAuth flow."""
state = params.state or secrets.token_hex(16) state = params.state or secrets.token_hex(16)
@@ -92,9 +90,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
self.state_mapping[state] = { self.state_mapping[state] = {
"redirect_uri": str(params.redirect_uri), "redirect_uri": str(params.redirect_uri),
"code_challenge": params.code_challenge, "code_challenge": params.code_challenge,
"redirect_uri_provided_explicitly": str( "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
params.redirect_uri_provided_explicitly
),
"client_id": client.client_id, "client_id": client.client_id,
} }
@@ -117,9 +113,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
redirect_uri = state_data["redirect_uri"] redirect_uri = state_data["redirect_uri"]
code_challenge = state_data["code_challenge"] code_challenge = state_data["code_challenge"]
redirect_uri_provided_explicitly = ( redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True"
state_data["redirect_uri_provided_explicitly"] == "True"
)
client_id = state_data["client_id"] client_id = state_data["client_id"]
# Exchange code for token with GitHub # Exchange code for token with GitHub
@@ -200,8 +194,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
for token, data in self.tokens.items() for token, data in self.tokens.items()
# see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/
# which you get depends on your GH app setup. # which you get depends on your GH app setup.
if (token.startswith("ghu_") or token.startswith("gho_")) if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id
and data.client_id == client.client_id
), ),
None, None,
) )
@@ -232,9 +225,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
return access_token return access_token
async def load_refresh_token( async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshToken | None:
"""Load a refresh token - not supported.""" """Load a refresh token - not supported."""
return None return None
@@ -247,9 +238,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
"""Exchange refresh token""" """Exchange refresh token"""
raise NotImplementedError("Not supported") raise NotImplementedError("Not supported")
async def revoke_token( async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
self, token: str, token_type_hint: str | None = None
) -> None:
"""Revoke a token.""" """Revoke a token."""
if token in self.tokens: if token in self.tokens:
del self.tokens[token] del self.tokens[token]
@@ -335,9 +324,7 @@ def create_simple_mcp_server(settings: ServerSettings) -> FastMCP:
) )
if response.status_code != 200: if response.status_code != 200:
raise ValueError( raise ValueError(f"GitHub API error: {response.status_code} - {response.text}")
f"GitHub API error: {response.status_code} - {response.text}"
)
return response.json() return response.json()
@@ -361,9 +348,7 @@ def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) ->
# No hardcoded credentials - all from environment variables # No hardcoded credentials - all from environment variables
settings = ServerSettings(host=host, port=port) settings = ServerSettings(host=host, port=port)
except ValueError as e: except ValueError as e:
logger.error( logger.error("Failed to load settings. Make sure environment variables are set:")
"Failed to load settings. Make sure environment variables are set:"
)
logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=<your-client-id>") logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=<your-client-id>")
logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=<your-client-secret>") logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=<your-client-secret>")
logger.error(f"Error: {e}") logger.error(f"Error: {e}")

View File

@@ -96,7 +96,7 @@ select = ["C4", "E", "F", "I", "PERF", "UP"]
ignore = ["PERF203"] ignore = ["PERF203"]
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 120
target-version = "py310" target-version = "py310"
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]

View File

@@ -21,9 +21,7 @@ def get_claude_config_path() -> Path | None:
elif sys.platform == "darwin": elif sys.platform == "darwin":
path = Path(Path.home(), "Library", "Application Support", "Claude") path = Path(Path.home(), "Library", "Application Support", "Claude")
elif sys.platform.startswith("linux"): elif sys.platform.startswith("linux"):
path = Path( path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude")
os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude"
)
else: else:
return None return None
@@ -37,8 +35,7 @@ def get_uv_path() -> str:
uv_path = shutil.which("uv") uv_path = shutil.which("uv")
if not uv_path: if not uv_path:
logger.error( logger.error(
"uv executable not found in PATH, falling back to 'uv'. " "uv executable not found in PATH, falling back to 'uv'. " "Please ensure uv is installed and in your PATH"
"Please ensure uv is installed and in your PATH"
) )
return "uv" # Fall back to just "uv" if not found return "uv" # Fall back to just "uv" if not found
return uv_path return uv_path
@@ -94,10 +91,7 @@ def update_claude_config(
config["mcpServers"] = {} config["mcpServers"] = {}
# Always preserve existing env vars and merge with new ones # Always preserve existing env vars and merge with new ones
if ( if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]:
server_name in config["mcpServers"]
and "env" in config["mcpServers"][server_name]
):
existing_env = config["mcpServers"][server_name]["env"] existing_env = config["mcpServers"][server_name]["env"]
if env_vars: if env_vars:
# New vars take precedence over existing ones # New vars take precedence over existing ones

View File

@@ -45,9 +45,7 @@ def _get_npx_command():
# Try both npx.cmd and npx.exe on Windows # Try both npx.cmd and npx.exe on Windows
for cmd in ["npx.cmd", "npx.exe", "npx"]: for cmd in ["npx.cmd", "npx.exe", "npx"]:
try: try:
subprocess.run( subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True)
[cmd, "--version"], check=True, capture_output=True, shell=True
)
return cmd return cmd
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
continue continue
@@ -58,9 +56,7 @@ def _get_npx_command():
def _parse_env_var(env_var: str) -> tuple[str, str]: def _parse_env_var(env_var: str) -> tuple[str, str]:
"""Parse environment variable string in format KEY=VALUE.""" """Parse environment variable string in format KEY=VALUE."""
if "=" not in env_var: if "=" not in env_var:
logger.error( logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE")
f"Invalid environment variable format: {env_var}. Must be KEY=VALUE"
)
sys.exit(1) sys.exit(1)
key, value = env_var.split("=", 1) key, value = env_var.split("=", 1)
return key.strip(), value.strip() return key.strip(), value.strip()
@@ -154,14 +150,10 @@ def _import_server(file: Path, server_object: str | None = None):
True if it's supported. True if it's supported.
""" """
if not isinstance(server_object, FastMCP): if not isinstance(server_object, FastMCP):
logger.error( logger.error(f"The server object {object_name} is of type " f"{type(server_object)} (expecting {FastMCP}).")
f"The server object {object_name} is of type "
f"{type(server_object)} (expecting {FastMCP})."
)
if isinstance(server_object, LowLevelServer): if isinstance(server_object, LowLevelServer):
logger.warning( logger.warning(
"Note that only FastMCP server is supported. Low level " "Note that only FastMCP server is supported. Low level " "Server class is not yet supported."
"Server class is not yet supported."
) )
return False return False
return True return True
@@ -172,10 +164,7 @@ def _import_server(file: Path, server_object: str | None = None):
for name in ["mcp", "server", "app"]: for name in ["mcp", "server", "app"]:
if hasattr(module, name): if hasattr(module, name):
if not _check_server_object(getattr(module, name), f"{file}:{name}"): if not _check_server_object(getattr(module, name), f"{file}:{name}"):
logger.error( logger.error(f"Ignoring object '{file}:{name}' as it's not a valid " "server object")
f"Ignoring object '{file}:{name}' as it's not a valid "
"server object"
)
continue continue
return getattr(module, name) return getattr(module, name)
@@ -280,8 +269,7 @@ def dev(
npx_cmd = _get_npx_command() npx_cmd = _get_npx_command()
if not npx_cmd: if not npx_cmd:
logger.error( logger.error(
"npx not found. Please ensure Node.js and npm are properly installed " "npx not found. Please ensure Node.js and npm are properly installed " "and added to your system PATH."
"and added to your system PATH."
) )
sys.exit(1) sys.exit(1)
@@ -383,8 +371,7 @@ def install(
typer.Option( typer.Option(
"--name", "--name",
"-n", "-n",
help="Custom name for the server (defaults to server's name attribute or" help="Custom name for the server (defaults to server's name attribute or" " file name)",
" file name)",
), ),
] = None, ] = None,
with_editable: Annotated[ with_editable: Annotated[
@@ -458,8 +445,7 @@ def install(
name = server.name name = server.name
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.debug( logger.debug(
"Could not import server (likely missing dependencies), using file" "Could not import server (likely missing dependencies), using file" " name",
" name",
extra={"error": str(e)}, extra={"error": str(e)},
) )
name = file.stem name = file.stem
@@ -477,11 +463,7 @@ def install(
if env_file: if env_file:
if dotenv: if dotenv:
try: try:
env_dict |= { env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None}
k: v
for k, v in dotenv.dotenv_values(env_file).items()
if v is not None
}
except Exception as e: except Exception as e:
logger.error(f"Failed to load .env file: {e}") logger.error(f"Failed to load .env file: {e}")
sys.exit(1) sys.exit(1)

View File

@@ -24,9 +24,7 @@ logger = logging.getLogger("client")
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
logger.error("Error: %s", message) logger.error("Error: %s", message)
@@ -60,9 +58,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]])
await run_session(*streams) await run_session(*streams)
else: else:
# Use stdio client for commands # Use stdio client for commands
server_parameters = StdioServerParameters( server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
command=command_or_url, args=args, env=env_dict
)
async with stdio_client(server_parameters) as streams: async with stdio_client(server_parameters) as streams:
await run_session(*streams) await run_session(*streams)

View File

@@ -17,12 +17,7 @@ from urllib.parse import urlencode, urljoin
import anyio import anyio
import httpx import httpx
from mcp.shared.auth import ( from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
)
from mcp.types import LATEST_PROTOCOL_VERSION from mcp.types import LATEST_PROTOCOL_VERSION
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,10 +95,7 @@ class OAuthClientProvider(httpx.Auth):
def _generate_code_verifier(self) -> str: def _generate_code_verifier(self) -> str:
"""Generate a cryptographically random code verifier for PKCE.""" """Generate a cryptographically random code verifier for PKCE."""
return "".join( return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
secrets.choice(string.ascii_letters + string.digits + "-._~")
for _ in range(128)
)
def _generate_code_challenge(self, code_verifier: str) -> str: def _generate_code_challenge(self, code_verifier: str) -> str:
"""Generate a code challenge from a code verifier using SHA256.""" """Generate a code challenge from a code verifier using SHA256."""
@@ -148,9 +140,7 @@ class OAuthClientProvider(httpx.Auth):
return None return None
response.raise_for_status() response.raise_for_status()
metadata_json = response.json() metadata_json = response.json()
logger.debug( logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}")
f"OAuth metadata discovered (no MCP header): {metadata_json}"
)
return OAuthMetadata.model_validate(metadata_json) return OAuthMetadata.model_validate(metadata_json)
except Exception: except Exception:
logger.exception("Failed to discover OAuth metadata") logger.exception("Failed to discover OAuth metadata")
@@ -176,17 +166,11 @@ class OAuthClientProvider(httpx.Auth):
registration_url = urljoin(auth_base_url, "/register") registration_url = urljoin(auth_base_url, "/register")
# Handle default scope # Handle default scope
if ( if client_metadata.scope is None and metadata and metadata.scopes_supported is not None:
client_metadata.scope is None
and metadata
and metadata.scopes_supported is not None
):
client_metadata.scope = " ".join(metadata.scopes_supported) client_metadata.scope = " ".join(metadata.scopes_supported)
# Serialize client metadata # Serialize client metadata
registration_data = client_metadata.model_dump( registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
by_alias=True, mode="json", exclude_none=True
)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
@@ -213,9 +197,7 @@ class OAuthClientProvider(httpx.Auth):
logger.exception("Registration error") logger.exception("Registration error")
raise raise
async def async_auth_flow( async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
""" """
HTTPX auth flow integration. HTTPX auth flow integration.
""" """
@@ -225,9 +207,7 @@ class OAuthClientProvider(httpx.Auth):
await self.ensure_token() await self.ensure_token()
# Add Bearer token if available # Add Bearer token if available
if self._current_tokens and self._current_tokens.access_token: if self._current_tokens and self._current_tokens.access_token:
request.headers["Authorization"] = ( request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}"
f"Bearer {self._current_tokens.access_token}"
)
response = yield request response = yield request
@@ -305,11 +285,7 @@ class OAuthClientProvider(httpx.Auth):
return return
# Try refreshing existing token # Try refreshing existing token
if ( if self._current_tokens and self._current_tokens.refresh_token and await self._refresh_access_token():
self._current_tokens
and self._current_tokens.refresh_token
and await self._refresh_access_token()
):
return return
# Fall back to full OAuth flow # Fall back to full OAuth flow
@@ -361,12 +337,8 @@ class OAuthClientProvider(httpx.Auth):
auth_code, returned_state = await self.callback_handler() auth_code, returned_state = await self.callback_handler()
# Validate state parameter for CSRF protection # Validate state parameter for CSRF protection
if returned_state is None or not secrets.compare_digest( if returned_state is None or not secrets.compare_digest(returned_state, self._auth_state):
returned_state, self._auth_state raise Exception(f"State parameter mismatch: {returned_state} != {self._auth_state}")
):
raise Exception(
f"State parameter mismatch: {returned_state} != {self._auth_state}"
)
# Clear state after validation # Clear state after validation
self._auth_state = None self._auth_state = None
@@ -377,9 +349,7 @@ class OAuthClientProvider(httpx.Auth):
# Exchange authorization code for tokens # Exchange authorization code for tokens
await self._exchange_code_for_token(auth_code, client_info) await self._exchange_code_for_token(auth_code, client_info)
async def _exchange_code_for_token( async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClientInformationFull) -> None:
self, auth_code: str, client_info: OAuthClientInformationFull
) -> None:
"""Exchange authorization code for access token.""" """Exchange authorization code for access token."""
# Get token endpoint # Get token endpoint
if self._metadata and self._metadata.token_endpoint: if self._metadata and self._metadata.token_endpoint:
@@ -412,17 +382,10 @@ class OAuthClientProvider(httpx.Auth):
# Parse OAuth error response # Parse OAuth error response
try: try:
error_data = response.json() error_data = response.json()
error_msg = error_data.get( error_msg = error_data.get("error_description", error_data.get("error", "Unknown error"))
"error_description", error_data.get("error", "Unknown error") raise Exception(f"Token exchange failed: {error_msg} " f"(HTTP {response.status_code})")
)
raise Exception(
f"Token exchange failed: {error_msg} "
f"(HTTP {response.status_code})"
)
except Exception: except Exception:
raise Exception( raise Exception(f"Token exchange failed: {response.status_code} {response.text}")
f"Token exchange failed: {response.status_code} {response.text}"
)
# Parse token response # Parse token response
token_response = OAuthToken.model_validate(response.json()) token_response = OAuthToken.model_validate(response.json())

View File

@@ -38,16 +38,12 @@ class LoggingFnT(Protocol):
class MessageHandlerFnT(Protocol): class MessageHandlerFnT(Protocol):
async def __call__( async def __call__(
self, self,
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ... ) -> None: ...
async def _default_message_handler( async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
@@ -77,9 +73,7 @@ async def _default_logging_callback(
pass pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
types.ClientResult | types.ErrorData
)
class ClientSession( class ClientSession(
@@ -116,11 +110,7 @@ class ClientSession(
self._message_handler = message_handler or _default_message_handler self._message_handler = message_handler or _default_message_handler
async def initialize(self) -> types.InitializeResult: async def initialize(self) -> types.InitializeResult:
sampling = ( sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
types.SamplingCapability()
if self._sampling_callback is not _default_sampling_callback
else None
)
roots = ( roots = (
# TODO: Should this be based on whether we # TODO: Should this be based on whether we
# _will_ send notifications, or only whether # _will_ send notifications, or only whether
@@ -149,15 +139,10 @@ class ClientSession(
) )
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError( raise RuntimeError("Unsupported protocol version from the server: " f"{result.protocolVersion}")
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
)
await self.send_notification( await self.send_notification(
types.ClientNotification( types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
types.InitializedNotification(method="notifications/initialized")
)
) )
return result return result
@@ -207,33 +192,25 @@ class ClientSession(
types.EmptyResult, types.EmptyResult,
) )
async def list_resources( async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult:
self, cursor: str | None = None
) -> types.ListResourcesResult:
"""Send a resources/list request.""" """Send a resources/list request."""
return await self.send_request( return await self.send_request(
types.ClientRequest( types.ClientRequest(
types.ListResourcesRequest( types.ListResourcesRequest(
method="resources/list", method="resources/list",
params=types.PaginatedRequestParams(cursor=cursor) params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
if cursor is not None
else None,
) )
), ),
types.ListResourcesResult, types.ListResourcesResult,
) )
async def list_resource_templates( async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult:
self, cursor: str | None = None
) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request.""" """Send a resources/templates/list request."""
return await self.send_request( return await self.send_request(
types.ClientRequest( types.ClientRequest(
types.ListResourceTemplatesRequest( types.ListResourceTemplatesRequest(
method="resources/templates/list", method="resources/templates/list",
params=types.PaginatedRequestParams(cursor=cursor) params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
if cursor is not None
else None,
) )
), ),
types.ListResourceTemplatesResult, types.ListResourceTemplatesResult,
@@ -305,17 +282,13 @@ class ClientSession(
types.ClientRequest( types.ClientRequest(
types.ListPromptsRequest( types.ListPromptsRequest(
method="prompts/list", method="prompts/list",
params=types.PaginatedRequestParams(cursor=cursor) params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
if cursor is not None
else None,
) )
), ),
types.ListPromptsResult, types.ListPromptsResult,
) )
async def get_prompt( async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
self, name: str, arguments: dict[str, str] | None = None
) -> types.GetPromptResult:
"""Send a prompts/get request.""" """Send a prompts/get request."""
return await self.send_request( return await self.send_request(
types.ClientRequest( types.ClientRequest(
@@ -352,9 +325,7 @@ class ClientSession(
types.ClientRequest( types.ClientRequest(
types.ListToolsRequest( types.ListToolsRequest(
method="tools/list", method="tools/list",
params=types.PaginatedRequestParams(cursor=cursor) params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
if cursor is not None
else None,
) )
), ),
types.ListToolsResult, types.ListToolsResult,
@@ -370,9 +341,7 @@ class ClientSession(
) )
) )
async def _received_request( async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any]( ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id, request_id=responder.request_id,
meta=responder.request_meta, meta=responder.request_meta,
@@ -395,22 +364,16 @@ class ClientSession(
case types.PingRequest(): case types.PingRequest():
with responder: with responder:
return await responder.respond( return await responder.respond(types.ClientResult(root=types.EmptyResult()))
types.ClientResult(root=types.EmptyResult())
)
async def _handle_incoming( async def _handle_incoming(
self, self,
req: RequestResponder[types.ServerRequest, types.ClientResult] req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
"""Handle incoming messages by forwarding to the message handler.""" """Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req) await self._message_handler(req)
async def _received_notification( async def _received_notification(self, notification: types.ServerNotification) -> None:
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server.""" """Handle notifications from the server."""
# Process specific notification types # Process specific notification types
match notification.root: match notification.root:

View File

@@ -62,9 +62,7 @@ class StreamableHttpParameters(BaseModel):
terminate_on_close: bool = True terminate_on_close: bool = True
ServerParameters: TypeAlias = ( ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
StdioServerParameters | SseServerParameters | StreamableHttpParameters
)
class ClientSessionGroup: class ClientSessionGroup:
@@ -261,9 +259,7 @@ class ClientSessionGroup:
) )
read, write, _ = await session_stack.enter_async_context(client) read, write, _ = await session_stack.enter_async_context(client)
session = await session_stack.enter_async_context( session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
mcp.ClientSession(read, write)
)
result = await session.initialize() result = await session.initialize()
# Session successfully initialized. # Session successfully initialized.
@@ -280,9 +276,7 @@ class ClientSessionGroup:
await session_stack.aclose() await session_stack.aclose()
raise raise
async def _aggregate_components( async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
self, server_info: types.Implementation, session: mcp.ClientSession
) -> None:
"""Aggregates prompts, resources, and tools from a given session.""" """Aggregates prompts, resources, and tools from a given session."""
# Create a reverse index so we can find all prompts, resources, and # Create a reverse index so we can find all prompts, resources, and

View File

@@ -73,20 +73,16 @@ async def sse_client(
match sse.event: match sse.event:
case "endpoint": case "endpoint":
endpoint_url = urljoin(url, sse.data) endpoint_url = urljoin(url, sse.data)
logger.debug( logger.debug(f"Received endpoint URL: {endpoint_url}")
f"Received endpoint URL: {endpoint_url}"
)
url_parsed = urlparse(url) url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url) endpoint_parsed = urlparse(endpoint_url)
if ( if (
url_parsed.netloc != endpoint_parsed.netloc url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme or url_parsed.scheme != endpoint_parsed.scheme
!= endpoint_parsed.scheme
): ):
error_msg = ( error_msg = (
"Endpoint origin does not match " "Endpoint origin does not match " f"connection origin: {endpoint_url}"
f"connection origin: {endpoint_url}"
) )
logger.error(error_msg) logger.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -98,22 +94,16 @@ async def sse_client(
message = types.JSONRPCMessage.model_validate_json( # noqa: E501 message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data sse.data
) )
logger.debug( logger.debug(f"Received server message: {message}")
f"Received server message: {message}"
)
except Exception as exc: except Exception as exc:
logger.error( logger.error(f"Error parsing server message: {exc}")
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
continue continue
session_message = SessionMessage(message) session_message = SessionMessage(message)
await read_stream_writer.send(session_message) await read_stream_writer.send(session_message)
case _: case _:
logger.warning( logger.warning(f"Unknown SSE event: {sse.event}")
f"Unknown SSE event: {sse.event}"
)
except Exception as exc: except Exception as exc:
logger.error(f"Error in sse_reader: {exc}") logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
@@ -124,9 +114,7 @@ async def sse_client(
try: try:
async with write_stream_reader: async with write_stream_reader:
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
logger.debug( logger.debug(f"Sending client message: {session_message}")
f"Sending client message: {session_message}"
)
response = await client.post( response = await client.post(
endpoint_url, endpoint_url,
json=session_message.message.model_dump( json=session_message.message.model_dump(
@@ -136,19 +124,14 @@ async def sse_client(
), ),
) )
response.raise_for_status() response.raise_for_status()
logger.debug( logger.debug("Client message sent successfully: " f"{response.status_code}")
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc: except Exception as exc:
logger.error(f"Error in post_writer: {exc}") logger.error(f"Error in post_writer: {exc}")
finally: finally:
await write_stream.aclose() await write_stream.aclose()
endpoint_url = await tg.start(sse_reader) endpoint_url = await tg.start(sse_reader)
logger.debug( logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url) tg.start_soon(post_writer, endpoint_url)
try: try:

View File

@@ -115,11 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
process = await _create_platform_compatible_process( process = await _create_platform_compatible_process(
command=command, command=command,
args=server.args, args=server.args,
env=( env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
{**get_default_environment(), **server.env}
if server.env is not None
else get_default_environment()
),
errlog=errlog, errlog=errlog,
cwd=server.cwd, cwd=server.cwd,
) )
@@ -163,9 +159,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
try: try:
async with write_stream_reader: async with write_stream_reader:
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
json = session_message.message.model_dump_json( json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
by_alias=True, exclude_none=True
)
await process.stdin.send( await process.stdin.send(
(json + "\n").encode( (json + "\n").encode(
encoding=server.encoding, encoding=server.encoding,
@@ -229,8 +223,6 @@ async def _create_platform_compatible_process(
if sys.platform == "win32": if sys.platform == "win32":
process = await create_windows_process(command, args, env, errlog, cwd) process = await create_windows_process(command, args, env, errlog, cwd)
else: else:
process = await anyio.open_process( process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd)
[command, *args], env=env, stderr=errlog, cwd=cwd
)
return process return process

View File

@@ -82,9 +82,7 @@ async def create_windows_process(
return process return process
except Exception: except Exception:
# Don't raise, let's try to create the process without creation flags # Don't raise, let's try to create the process without creation flags
process = await anyio.open_process( process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd)
[command, *args], env=env, stderr=errlog, cwd=cwd
)
return process return process

View File

@@ -106,9 +106,7 @@ class StreamableHTTPTransport:
**self.headers, **self.headers,
} }
def _update_headers_with_session( def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
self, base_headers: dict[str, str]
) -> dict[str, str]:
"""Update headers with session ID if available.""" """Update headers with session ID if available."""
headers = base_headers.copy() headers = base_headers.copy()
if self.session_id: if self.session_id:
@@ -117,17 +115,11 @@ class StreamableHTTPTransport:
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request.""" """Check if the message is an initialization request."""
return ( return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
isinstance(message.root, JSONRPCRequest)
and message.root.method == "initialize"
)
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification.""" """Check if the message is an initialized notification."""
return ( return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
isinstance(message.root, JSONRPCNotification)
and message.root.method == "notifications/initialized"
)
def _maybe_extract_session_id_from_response( def _maybe_extract_session_id_from_response(
self, self,
@@ -153,9 +145,7 @@ class StreamableHTTPTransport:
logger.debug(f"SSE message: {message}") logger.debug(f"SSE message: {message}")
# If this is a response and we have original_request_id, replace it # If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance( if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root, JSONRPCResponse | JSONRPCError
):
message.root.id = original_request_id message.root.id = original_request_id
session_message = SessionMessage(message) session_message = SessionMessage(message)
@@ -227,7 +217,8 @@ class StreamableHTTPTransport:
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout( timeout=httpx.Timeout(
self.timeout.total_seconds(), read=ctx.sse_read_timeout.total_seconds() self.timeout.total_seconds(),
read=ctx.sse_read_timeout.total_seconds(),
), ),
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
@@ -298,9 +289,7 @@ class StreamableHTTPTransport:
logger.error(f"Error parsing JSON response: {exc}") logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc) await read_stream_writer.send(exc)
async def _handle_sse_response( async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
self, response: httpx.Response, ctx: RequestContext
) -> None:
"""Handle SSE response from the server.""" """Handle SSE response from the server."""
try: try:
event_source = EventSource(response) event_source = EventSource(response)
@@ -308,11 +297,7 @@ class StreamableHTTPTransport:
is_complete = await self._handle_sse_event( is_complete = await self._handle_sse_event(
sse, sse,
ctx.read_stream_writer, ctx.read_stream_writer,
resumption_callback=( resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
ctx.metadata.on_resumption_token_update
if ctx.metadata
else None
),
) )
# If the SSE event indicates completion, like returning respose/error # If the SSE event indicates completion, like returning respose/error
# break the loop # break the loop
@@ -455,12 +440,8 @@ async def streamablehttp_client(
""" """
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
read_stream_writer, read_stream = anyio.create_memory_object_stream[ read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
SessionMessage | Exception write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[
SessionMessage
](0)
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
try: try:
@@ -476,9 +457,7 @@ async def streamablehttp_client(
) as client: ) as client:
# Define callbacks that need access to tg # Define callbacks that need access to tg
def start_get_stream() -> None: def start_get_stream() -> None:
tg.start_soon( tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
transport.handle_get_stream, client, read_stream_writer
)
tg.start_soon( tg.start_soon(
transport.post_writer, transport.post_writer,

View File

@@ -19,10 +19,7 @@ logger = logging.getLogger(__name__)
async def websocket_client( async def websocket_client(
url: str, url: str,
) -> AsyncGenerator[ ) -> AsyncGenerator[
tuple[ tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
],
None, None,
]: ]:
""" """
@@ -74,9 +71,7 @@ async def websocket_client(
async with write_stream_reader: async with write_stream_reader:
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
# Convert to a dict, then to JSON # Convert to a dict, then to JSON
msg_dict = session_message.message.model_dump( msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict)) await ws.send(json.dumps(msg_dict))
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:

View File

@@ -2,7 +2,4 @@ from pydantic import ValidationError
def stringify_pydantic_error(validation_error: ValidationError) -> str: def stringify_pydantic_error(validation_error: ValidationError) -> str:
return "\n".join( return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
for e in validation_error.errors()
)

View File

@@ -7,9 +7,7 @@ from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import RedirectResponse, Response from starlette.responses import RedirectResponse, Response
from mcp.server.auth.errors import ( from mcp.server.auth.errors import stringify_pydantic_error
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import ( from mcp.server.auth.provider import (
AuthorizationErrorCode, AuthorizationErrorCode,
@@ -18,10 +16,7 @@ from mcp.server.auth.provider import (
OAuthAuthorizationServerProvider, OAuthAuthorizationServerProvider,
construct_redirect_uri, construct_redirect_uri,
) )
from mcp.shared.auth import ( from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
InvalidRedirectUriError,
InvalidScopeError,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,23 +24,16 @@ logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel): class AuthorizationRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
client_id: str = Field(..., description="The client ID") client_id: str = Field(..., description="The client ID")
redirect_uri: AnyUrl | None = Field( redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
None, description="URL to redirect to after authorization"
)
# see OAuthClientMetadata; we only support `code` # see OAuthClientMetadata; we only support `code`
response_type: Literal["code"] = Field( response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
..., description="Must be 'code' for authorization code flow"
)
code_challenge: str = Field(..., description="PKCE code challenge") code_challenge: str = Field(..., description="PKCE code challenge")
code_challenge_method: Literal["S256"] = Field( code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
"S256", description="PKCE code challenge method, must be S256"
)
state: str | None = Field(None, description="Optional state parameter") state: str | None = Field(None, description="Optional state parameter")
scope: str | None = Field( scope: str | None = Field(
None, None,
description="Optional scope; if specified, should be " description="Optional scope; if specified, should be " "a space-separated list of scope strings",
"a space-separated list of scope strings",
) )
@@ -57,9 +45,7 @@ class AuthorizationErrorResponse(BaseModel):
state: str | None = None state: str | None = None
def best_effort_extract_string( def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
key: str, params: None | FormData | QueryParams
) -> str | None:
if params is None: if params is None:
return None return None
value = params.get(key) value = params.get(key)
@@ -138,9 +124,7 @@ class AuthorizationHandler:
if redirect_uri and client: if redirect_uri and client:
return RedirectResponse( return RedirectResponse(
url=construct_redirect_uri( url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
str(redirect_uri), **error_resp.model_dump(exclude_none=True)
),
status_code=302, status_code=302,
headers={"Cache-Control": "no-store"}, headers={"Cache-Control": "no-store"},
) )
@@ -172,9 +156,7 @@ class AuthorizationHandler:
if e["loc"] == ("response_type",) and e["type"] == "literal_error": if e["loc"] == ("response_type",) and e["type"] == "literal_error":
error = "unsupported_response_type" error = "unsupported_response_type"
break break
return await error_response( return await error_response(error, stringify_pydantic_error(validation_error))
error, stringify_pydantic_error(validation_error)
)
# Get client information # Get client information
client = await self.provider.get_client( client = await self.provider.get_client(
@@ -229,16 +211,9 @@ class AuthorizationHandler:
) )
except AuthorizeError as e: except AuthorizeError as e:
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
return await error_response( return await error_response(error=e.error, error_description=e.error_description)
error=e.error,
error_description=e.error_description,
)
except Exception as validation_error: except Exception as validation_error:
# Catch-all for unexpected errors # Catch-all for unexpected errors
logger.exception( logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
"Unexpected error in authorization_handler", exc_info=validation_error return await error_response(error="server_error", error_description="An unexpected error occurred")
)
return await error_response(
error="server_error", error_description="An unexpected error occurred"
)

View File

@@ -10,11 +10,7 @@ from starlette.responses import Response
from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import ( from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
OAuthAuthorizationServerProvider,
RegistrationError,
RegistrationErrorCode,
)
from mcp.server.auth.settings import ClientRegistrationOptions from mcp.server.auth.settings import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
@@ -60,9 +56,7 @@ class RegistrationHandler:
if client_metadata.scope is None and self.options.default_scopes is not None: if client_metadata.scope is None and self.options.default_scopes is not None:
client_metadata.scope = " ".join(self.options.default_scopes) client_metadata.scope = " ".join(self.options.default_scopes)
elif ( elif client_metadata.scope is not None and self.options.valid_scopes is not None:
client_metadata.scope is not None and self.options.valid_scopes is not None
):
requested_scopes = set(client_metadata.scope.split()) requested_scopes = set(client_metadata.scope.split())
valid_scopes = set(self.options.valid_scopes) valid_scopes = set(self.options.valid_scopes)
if not requested_scopes.issubset(valid_scopes): if not requested_scopes.issubset(valid_scopes):
@@ -78,8 +72,7 @@ class RegistrationHandler:
return PydanticJSONResponse( return PydanticJSONResponse(
content=RegistrationErrorResponse( content=RegistrationErrorResponse(
error="invalid_client_metadata", error="invalid_client_metadata",
error_description="grant_types must be authorization_code " error_description="grant_types must be authorization_code " "and refresh_token",
"and refresh_token",
), ),
status_code=400, status_code=400,
) )
@@ -122,8 +115,6 @@ class RegistrationHandler:
except RegistrationError as e: except RegistrationError as e:
# Handle registration errors as defined in RFC 7591 Section 3.2.2 # Handle registration errors as defined in RFC 7591 Section 3.2.2
return PydanticJSONResponse( return PydanticJSONResponse(
content=RegistrationErrorResponse( content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
error=e.error, error_description=e.error_description
),
status_code=400, status_code=400,
) )

View File

@@ -10,15 +10,8 @@ from mcp.server.auth.errors import (
stringify_pydantic_error, stringify_pydantic_error,
) )
from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import ( from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
AuthenticationError, from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
ClientAuthenticator,
)
from mcp.server.auth.provider import (
AccessToken,
OAuthAuthorizationServerProvider,
RefreshToken,
)
class RevocationRequest(BaseModel): class RevocationRequest(BaseModel):

View File

@@ -7,19 +7,10 @@ from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request from starlette.requests import Request
from mcp.server.auth.errors import ( from mcp.server.auth.errors import stringify_pydantic_error
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import ( from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
AuthenticationError, from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
ClientAuthenticator,
)
from mcp.server.auth.provider import (
OAuthAuthorizationServerProvider,
TokenError,
TokenErrorCode,
)
from mcp.shared.auth import OAuthToken from mcp.shared.auth import OAuthToken
@@ -27,9 +18,7 @@ class AuthorizationCodeRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
grant_type: Literal["authorization_code"] grant_type: Literal["authorization_code"]
code: str = Field(..., description="The authorization code") code: str = Field(..., description="The authorization code")
redirect_uri: AnyUrl | None = Field( redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
None, description="Must be the same as redirect URI provided in /authorize"
)
client_id: str client_id: str
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None client_secret: str | None = None
@@ -127,8 +116,7 @@ class TokenHandler:
TokenErrorResponse( TokenErrorResponse(
error="unsupported_grant_type", error="unsupported_grant_type",
error_description=( error_description=(
f"Unsupported grant type (supported grant types are " f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})"
f"{client_info.grant_types})"
), ),
) )
) )
@@ -137,9 +125,7 @@ class TokenHandler:
match token_request: match token_request:
case AuthorizationCodeRequest(): case AuthorizationCodeRequest():
auth_code = await self.provider.load_authorization_code( auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
client_info, token_request.code
)
if auth_code is None or auth_code.client_id != token_request.client_id: if auth_code is None or auth_code.client_id != token_request.client_id:
# if code belongs to different client, pretend it doesn't exist # if code belongs to different client, pretend it doesn't exist
return self.response( return self.response(
@@ -169,18 +155,13 @@ class TokenHandler:
return self.response( return self.response(
TokenErrorResponse( TokenErrorResponse(
error="invalid_request", error="invalid_request",
error_description=( error_description=("redirect_uri did not match the one " "used when creating auth code"),
"redirect_uri did not match the one "
"used when creating auth code"
),
) )
) )
# Verify PKCE code verifier # Verify PKCE code verifier
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
hashed_code_verifier = ( hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
base64.urlsafe_b64encode(sha256).decode().rstrip("=")
)
if hashed_code_verifier != auth_code.code_challenge: if hashed_code_verifier != auth_code.code_challenge:
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
@@ -193,9 +174,7 @@ class TokenHandler:
try: try:
# Exchange authorization code for tokens # Exchange authorization code for tokens
tokens = await self.provider.exchange_authorization_code( tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
client_info, auth_code
)
except TokenError as e: except TokenError as e:
return self.response( return self.response(
TokenErrorResponse( TokenErrorResponse(
@@ -205,13 +184,8 @@ class TokenHandler:
) )
case RefreshTokenRequest(): case RefreshTokenRequest():
refresh_token = await self.provider.load_refresh_token( refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
client_info, token_request.refresh_token if refresh_token is None or refresh_token.client_id != token_request.client_id:
)
if (
refresh_token is None
or refresh_token.client_id != token_request.client_id
):
# if token belongs to different client, pretend it doesn't exist # if token belongs to different client, pretend it doesn't exist
return self.response( return self.response(
TokenErrorResponse( TokenErrorResponse(
@@ -230,29 +204,20 @@ class TokenHandler:
) )
# Parse scopes if provided # Parse scopes if provided
scopes = ( scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
token_request.scope.split(" ")
if token_request.scope
else refresh_token.scopes
)
for scope in scopes: for scope in scopes:
if scope not in refresh_token.scopes: if scope not in refresh_token.scopes:
return self.response( return self.response(
TokenErrorResponse( TokenErrorResponse(
error="invalid_scope", error="invalid_scope",
error_description=( error_description=(f"cannot request scope `{scope}` " "not provided by refresh token"),
f"cannot request scope `{scope}` "
"not provided by refresh token"
),
) )
) )
try: try:
# Exchange refresh token for new tokens # Exchange refresh token for new tokens
tokens = await self.provider.exchange_refresh_token( tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
client_info, refresh_token, scopes
)
except TokenError as e: except TokenError as e:
return self.response( return self.response(
TokenErrorResponse( TokenErrorResponse(

View File

@@ -7,9 +7,7 @@ from mcp.server.auth.provider import AccessToken
# Create a contextvar to store the authenticated user # Create a contextvar to store the authenticated user
# The default is None, indicating no authenticated user is present # The default is None, indicating no authenticated user is present
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]( auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
"auth_context", default=None
)
def get_access_token() -> AccessToken | None: def get_access_token() -> AccessToken | None:

View File

@@ -1,11 +1,7 @@
import time import time
from typing import Any from typing import Any
from starlette.authentication import ( from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
AuthCredentials,
AuthenticationBackend,
SimpleUser,
)
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
@@ -35,11 +31,7 @@ class BearerAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection): async def authenticate(self, conn: HTTPConnection):
auth_header = next( auth_header = next(
( (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
conn.headers.get(key)
for key in conn.headers
if key.lower() == "authorization"
),
None, None,
) )
if not auth_header or not auth_header.lower().startswith("bearer "): if not auth_header or not auth_header.lower().startswith("bearer "):
@@ -87,10 +79,7 @@ class RequireAuthMiddleware:
for required_scope in self.required_scopes: for required_scope in self.required_scopes:
# auth_credentials should always be provided; this is just paranoia # auth_credentials should always be provided; this is just paranoia
if ( if auth_credentials is None or required_scope not in auth_credentials.scopes:
auth_credentials is None
or required_scope not in auth_credentials.scopes
):
raise HTTPException(status_code=403, detail="Insufficient scope") raise HTTPException(status_code=403, detail="Insufficient scope")
await self.app(scope, receive, send) await self.app(scope, receive, send)

View File

@@ -30,9 +30,7 @@ class ClientAuthenticator:
""" """
self.provider = provider self.provider = provider
async def authenticate( async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
self, client_id: str, client_secret: str | None
) -> OAuthClientInformationFull:
# Look up client information # Look up client information
client = await self.provider.get_client(client_id) client = await self.provider.get_client(client_id)
if not client: if not client:
@@ -47,10 +45,7 @@ class ClientAuthenticator:
if client.client_secret != client_secret: if client.client_secret != client_secret:
raise AuthenticationError("Invalid client_secret") raise AuthenticationError("Invalid client_secret")
if ( if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
client.client_secret_expires_at
and client.client_secret_expires_at < int(time.time())
):
raise AuthenticationError("Client secret has expired") raise AuthenticationError("Client secret has expired")
return client return client

View File

@@ -4,10 +4,7 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from pydantic import AnyUrl, BaseModel from pydantic import AnyUrl, BaseModel
from mcp.shared.auth import ( from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
OAuthClientInformationFull,
OAuthToken,
)
class AuthorizationParams(BaseModel): class AuthorizationParams(BaseModel):
@@ -96,9 +93,7 @@ RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
class OAuthAuthorizationServerProvider( class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]
):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
""" """
Retrieves client information by client ID. Retrieves client information by client ID.
@@ -129,9 +124,7 @@ class OAuthAuthorizationServerProvider(
""" """
... ...
async def authorize( async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
""" """
Called as part of the /authorize endpoint, and returns a URL that the client Called as part of the /authorize endpoint, and returns a URL that the client
will be redirected to. will be redirected to.
@@ -207,9 +200,7 @@ class OAuthAuthorizationServerProvider(
""" """
... ...
async def load_refresh_token( async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshTokenT | None:
""" """
Loads a RefreshToken by its token string. Loads a RefreshToken by its token string.

View File

@@ -31,11 +31,7 @@ def validate_issuer_url(url: AnyHttpUrl):
""" """
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
if ( if url.scheme != "https" and url.host != "localhost" and not url.host.startswith("127.0.0.1"):
url.scheme != "https"
and url.host != "localhost"
and not url.host.startswith("127.0.0.1")
):
raise ValueError("Issuer URL must be HTTPS") raise ValueError("Issuer URL must be HTTPS")
# No fragments or query parameters allowed # No fragments or query parameters allowed
@@ -73,9 +69,7 @@ def create_auth_routes(
) -> list[Route]: ) -> list[Route]:
validate_issuer_url(issuer_url) validate_issuer_url(issuer_url)
client_registration_options = ( client_registration_options = client_registration_options or ClientRegistrationOptions()
client_registration_options or ClientRegistrationOptions()
)
revocation_options = revocation_options or RevocationOptions() revocation_options = revocation_options or RevocationOptions()
metadata = build_metadata( metadata = build_metadata(
issuer_url, issuer_url,
@@ -177,15 +171,11 @@ def build_metadata(
# Add registration endpoint if supported # Add registration endpoint if supported
if client_registration_options.enabled: if client_registration_options.enabled:
metadata.registration_endpoint = AnyHttpUrl( metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
str(issuer_url).rstrip("/") + REGISTRATION_PATH
)
# Add revocation endpoint if supported # Add revocation endpoint if supported
if revocation_options.enabled: if revocation_options.enabled:
metadata.revocation_endpoint = AnyHttpUrl( metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
str(issuer_url).rstrip("/") + REVOCATION_PATH
)
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
return metadata return metadata

View File

@@ -15,8 +15,7 @@ class RevocationOptions(BaseModel):
class AuthSettings(BaseModel): class AuthSettings(BaseModel):
issuer_url: AnyHttpUrl = Field( issuer_url: AnyHttpUrl = Field(
..., ...,
description="URL advertised as OAuth issuer; this should be the URL the server " description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at",
"is reachable at",
) )
service_documentation_url: AnyHttpUrl | None = None service_documentation_url: AnyHttpUrl | None = None
client_registration_options: ClientRegistrationOptions | None = None client_registration_options: ClientRegistrationOptions | None = None

View File

@@ -42,13 +42,9 @@ class AssistantMessage(Message):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
message_validator = TypeAdapter[UserMessage | AssistantMessage]( message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)
UserMessage | AssistantMessage
)
SyncPromptResult = ( SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
)
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
@@ -56,24 +52,16 @@ class PromptArgument(BaseModel):
"""An argument that can be passed to a prompt.""" """An argument that can be passed to a prompt."""
name: str = Field(description="Name of the argument") name: str = Field(description="Name of the argument")
description: str | None = Field( description: str | None = Field(None, description="Description of what the argument does")
None, description="Description of what the argument does" required: bool = Field(default=False, description="Whether the argument is required")
)
required: bool = Field(
default=False, description="Whether the argument is required"
)
class Prompt(BaseModel): class Prompt(BaseModel):
"""A prompt template that can be rendered with parameters.""" """A prompt template that can be rendered with parameters."""
name: str = Field(description="Name of the prompt") name: str = Field(description="Name of the prompt")
description: str | None = Field( description: str | None = Field(None, description="Description of what the prompt does")
None, description="Description of what the prompt does" arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
)
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod @classmethod
@@ -154,14 +142,10 @@ class Prompt(BaseModel):
content = TextContent(type="text", text=msg) content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content)) messages.append(UserMessage(content=content))
else: else:
content = pydantic_core.to_json( content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
msg, fallback=str, indent=2
).decode()
messages.append(Message(role="user", content=content)) messages.append(Message(role="user", content=content))
except Exception: except Exception:
raise ValueError( raise ValueError(f"Could not convert prompt result to message: {msg}")
f"Could not convert prompt result to message: {msg}"
)
return messages return messages
except Exception as e: except Exception as e:

View File

@@ -39,9 +39,7 @@ class PromptManager:
self._prompts[prompt.name] = prompt self._prompts[prompt.name] = prompt
return prompt return prompt
async def render_prompt( async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
self, name: str, arguments: dict[str, Any] | None = None
) -> list[Message]:
"""Render a prompt by name with arguments.""" """Render a prompt by name with arguments."""
prompt = self.get_prompt(name) prompt = self.get_prompt(name)
if not prompt: if not prompt:

View File

@@ -19,13 +19,9 @@ class Resource(BaseModel, abc.ABC):
model_config = ConfigDict(validate_default=True) model_config = ConfigDict(validate_default=True)
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field( uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
default=..., description="URI of the resource"
)
name: str | None = Field(description="Name of the resource", default=None) name: str | None = Field(description="Name of the resource", default=None)
description: str | None = Field( description: str | None = Field(description="Description of the resource", default=None)
description="Description of the resource", default=None
)
mime_type: str = Field( mime_type: str = Field(
default="text/plain", default="text/plain",
description="MIME type of the resource content", description="MIME type of the resource content",

View File

@@ -15,18 +15,12 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource
class ResourceTemplate(BaseModel): class ResourceTemplate(BaseModel):
"""A template for dynamically creating resources.""" """A template for dynamically creating resources."""
uri_template: str = Field( uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
description="URI template with parameters (e.g. weather://{city}/current)"
)
name: str = Field(description="Name of the resource") name: str = Field(description="Name of the resource")
description: str | None = Field(description="Description of what the resource does") description: str | None = Field(description="Description of what the resource does")
mime_type: str = Field( mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
default="text/plain", description="MIME type of the resource content"
)
fn: Callable[..., Any] = Field(exclude=True) fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field( parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
description="JSON schema for function parameters"
)
@classmethod @classmethod
def from_function( def from_function(

View File

@@ -54,9 +54,7 @@ class FunctionResource(Resource):
async def read(self) -> str | bytes: async def read(self) -> str | bytes:
"""Read the resource by calling the wrapped function.""" """Read the resource by calling the wrapped function."""
try: try:
result = ( result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
)
if isinstance(result, Resource): if isinstance(result, Resource):
return await result.read() return await result.read()
elif isinstance(result, bytes): elif isinstance(result, bytes):
@@ -141,9 +139,7 @@ class HttpResource(Resource):
"""A resource that reads from an HTTP endpoint.""" """A resource that reads from an HTTP endpoint."""
url: str = Field(description="URL to fetch content from") url: str = Field(description="URL to fetch content from")
mime_type: str = Field( mime_type: str = Field(default="application/json", description="MIME type of the resource content")
default="application/json", description="MIME type of the resource content"
)
async def read(self) -> str | bytes: async def read(self) -> str | bytes:
"""Read the HTTP content.""" """Read the HTTP content."""
@@ -157,15 +153,9 @@ class DirectoryResource(Resource):
"""A resource that lists files in a directory.""" """A resource that lists files in a directory."""
path: Path = Field(description="Path to the directory") path: Path = Field(description="Path to the directory")
recursive: bool = Field( recursive: bool = Field(default=False, description="Whether to list files recursively")
default=False, description="Whether to list files recursively" pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
) mime_type: str = Field(default="application/json", description="MIME type of the resource content")
pattern: str | None = Field(
default=None, description="Optional glob pattern to filter files"
)
mime_type: str = Field(
default="application/json", description="MIME type of the resource content"
)
@pydantic.field_validator("path") @pydantic.field_validator("path")
@classmethod @classmethod
@@ -184,16 +174,8 @@ class DirectoryResource(Resource):
try: try:
if self.pattern: if self.pattern:
return ( return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern))
list(self.path.glob(self.pattern)) return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*"))
if not self.recursive
else list(self.path.rglob(self.pattern))
)
return (
list(self.path.glob("*"))
if not self.recursive
else list(self.path.rglob("*"))
)
except Exception as e: except Exception as e:
raise ValueError(f"Error listing directory {self.path}: {e}") raise ValueError(f"Error listing directory {self.path}: {e}")

View File

@@ -97,9 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
# StreamableHTTP settings # StreamableHTTP settings
json_response: bool = False json_response: bool = False
stateless_http: bool = ( stateless_http: bool = False # If True, uses true stateless mode (new transport per request)
False # If True, uses true stateless mode (new transport per request)
)
# resource settings # resource settings
warn_on_duplicate_resources: bool = True warn_on_duplicate_resources: bool = True
@@ -115,9 +113,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
description="List of dependencies to install in the server environment", description="List of dependencies to install in the server environment",
) )
lifespan: ( lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None None, description="Lifespan context manager"
) = Field(None, description="Lifespan context manager") )
auth: AuthSettings | None = None auth: AuthSettings | None = None
@@ -125,9 +123,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper( def lifespan_wrapper(
app: FastMCP, app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[ ) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager @asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context: async with lifespan(app) as context:
@@ -141,8 +137,7 @@ class FastMCP:
self, self,
name: str | None = None, name: str | None = None,
instructions: str | None = None, instructions: str | None = None,
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None,
| None = None,
event_store: EventStore | None = None, event_store: EventStore | None = None,
*, *,
tools: list[Tool] | None = None, tools: list[Tool] | None = None,
@@ -153,31 +148,18 @@ class FastMCP:
self._mcp_server = MCPServer( self._mcp_server = MCPServer(
name=name or "FastMCP", name=name or "FastMCP",
instructions=instructions, instructions=instructions,
lifespan=( lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan
),
)
self._tool_manager = ToolManager(
tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
)
self._resource_manager = ResourceManager(
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
)
self._prompt_manager = PromptManager(
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
) )
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
if (self.settings.auth is not None) != (auth_server_provider is not None): if (self.settings.auth is not None) != (auth_server_provider is not None):
# TODO: after we support separate authorization servers (see # TODO: after we support separate authorization servers (see
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
# we should validate that if auth is enabled, we have either an # we should validate that if auth is enabled, we have either an
# auth_server_provider to host our own authorization server, # auth_server_provider to host our own authorization server,
# OR the URL of a 3rd party authorization server. # OR the URL of a 3rd party authorization server.
raise ValueError( raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified")
"settings.auth must be specified if and only if auth_server_provider "
"is specified"
)
self._auth_server_provider = auth_server_provider self._auth_server_provider = auth_server_provider
self._event_store = event_store self._event_store = event_store
self._custom_starlette_routes: list[Route] = [] self._custom_starlette_routes: list[Route] = []
@@ -340,9 +322,7 @@ class FastMCP:
description: Optional description of what the tool does description: Optional description of what the tool does
annotations: Optional ToolAnnotations providing additional tool information annotations: Optional ToolAnnotations providing additional tool information
""" """
self._tool_manager.add_tool( self._tool_manager.add_tool(fn, name=name, description=description, annotations=annotations)
fn, name=name, description=description, annotations=annotations
)
def tool( def tool(
self, self,
@@ -379,14 +359,11 @@ class FastMCP:
# Check if user passed function directly instead of calling decorator # Check if user passed function directly instead of calling decorator
if callable(name): if callable(name):
raise TypeError( raise TypeError(
"The @tool decorator was used incorrectly. " "The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool"
"Did you forget to call it? Use @tool() instead of @tool"
) )
def decorator(fn: AnyFunction) -> AnyFunction: def decorator(fn: AnyFunction) -> AnyFunction:
self.add_tool( self.add_tool(fn, name=name, description=description, annotations=annotations)
fn, name=name, description=description, annotations=annotations
)
return fn return fn
return decorator return decorator
@@ -462,8 +439,7 @@ class FastMCP:
if uri_params != func_params: if uri_params != func_params:
raise ValueError( raise ValueError(
f"Mismatch between URI parameters {uri_params} " f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}"
f"and function parameters {func_params}"
) )
# Register as template # Register as template
@@ -496,9 +472,7 @@ class FastMCP:
""" """
self._prompt_manager.add_prompt(prompt) self._prompt_manager.add_prompt(prompt)
def prompt( def prompt(self, name: str | None = None, description: str | None = None) -> Callable[[AnyFunction], AnyFunction]:
self, name: str | None = None, description: str | None = None
) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a prompt. """Decorator to register a prompt.
Args: Args:
@@ -665,9 +639,7 @@ class FastMCP:
self.settings.mount_path = mount_path self.settings.mount_path = mount_path
# Create normalized endpoint considering the mount path # Create normalized endpoint considering the mount path
normalized_message_endpoint = self._normalize_path( normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path)
self.settings.mount_path, self.settings.message_path
)
# Set up auth context and dependencies # Set up auth context and dependencies
@@ -764,9 +736,7 @@ class FastMCP:
routes.extend(self._custom_starlette_routes) routes.extend(self._custom_starlette_routes)
# Create Starlette app with routes and middleware # Create Starlette app with routes and middleware
return Starlette( return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)
debug=self.settings.debug, routes=routes, middleware=middleware
)
def streamable_http_app(self) -> Starlette: def streamable_http_app(self) -> Starlette:
"""Return an instance of the StreamableHTTP server app.""" """Return an instance of the StreamableHTTP server app."""
@@ -783,9 +753,7 @@ class FastMCP:
) )
# Create the ASGI handler # Create the ASGI handler
async def handle_streamable_http( async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
scope: Scope, receive: Receive, send: Send
) -> None:
await self.session_manager.handle_request(scope, receive, send) await self.session_manager.handle_request(scope, receive, send)
# Create routes # Create routes
@@ -861,9 +829,7 @@ class FastMCP:
for prompt in prompts for prompt in prompts
] ]
async def get_prompt( async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
self, name: str, arguments: dict[str, Any] | None = None
) -> GetPromptResult:
"""Get a prompt by name with arguments.""" """Get a prompt by name with arguments."""
try: try:
messages = await self._prompt_manager.render_prompt(name, arguments) messages = await self._prompt_manager.render_prompt(name, arguments)
@@ -936,9 +902,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
def __init__( def __init__(
self, self,
*, *,
request_context: ( request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None,
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
) = None,
fastmcp: FastMCP | None = None, fastmcp: FastMCP | None = None,
**kwargs: Any, **kwargs: Any,
): ):
@@ -962,9 +926,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
raise ValueError("Context is not available outside of a request") raise ValueError("Context is not available outside of a request")
return self._request_context return self._request_context
async def report_progress( async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
self, progress: float, total: float | None = None, message: str | None = None
) -> None:
"""Report progress for the current operation. """Report progress for the current operation.
Args: Args:
@@ -972,11 +934,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
total: Optional total value e.g. 100 total: Optional total value e.g. 100
message: Optional message e.g. Starting render... message: Optional message e.g. Starting render...
""" """
progress_token = ( progress_token = self.request_context.meta.progressToken if self.request_context.meta else None
self.request_context.meta.progressToken
if self.request_context.meta
else None
)
if progress_token is None: if progress_token is None:
return return
@@ -997,9 +955,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
Returns: Returns:
The resource content as either text or bytes The resource content as either text or bytes
""" """
assert ( assert self._fastmcp is not None, "Context is not available outside of a request"
self._fastmcp is not None
), "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri) return await self._fastmcp.read_resource(uri)
async def log( async def log(
@@ -1027,11 +983,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
@property @property
def client_id(self) -> str | None: def client_id(self) -> str | None:
"""Get the client ID if available.""" """Get the client ID if available."""
return ( return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None
getattr(self.request_context.meta, "client_id", None)
if self.request_context.meta
else None
)
@property @property
def request_id(self) -> str: def request_id(self) -> str:

View File

@@ -25,16 +25,11 @@ class Tool(BaseModel):
description: str = Field(description="Description of what the tool does") description: str = Field(description="Description of what the tool does")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field( fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool" description="Metadata about the function including a pydantic model for tool" " arguments"
" arguments"
) )
is_async: bool = Field(description="Whether the tool is async") is_async: bool = Field(description="Whether the tool is async")
context_kwarg: str | None = Field( context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
None, description="Name of the kwarg that should receive context" annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
)
annotations: ToolAnnotations | None = Field(
None, description="Optional annotations for the tool"
)
@classmethod @classmethod
def from_function( def from_function(
@@ -93,9 +88,7 @@ class Tool(BaseModel):
self.fn, self.fn,
self.is_async, self.is_async,
arguments, arguments,
{self.context_kwarg: context} {self.context_kwarg: context} if self.context_kwarg is not None else None,
if self.context_kwarg is not None
else None,
) )
except Exception as e: except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e raise ToolError(f"Error executing tool {self.name}: {e}") from e

View File

@@ -50,9 +50,7 @@ class ToolManager:
annotations: ToolAnnotations | None = None, annotations: ToolAnnotations | None = None,
) -> Tool: ) -> Tool:
"""Add a tool to the server.""" """Add a tool to the server."""
tool = Tool.from_function( tool = Tool.from_function(fn, name=name, description=description, annotations=annotations)
fn, name=name, description=description, annotations=annotations
)
existing = self._tools.get(tool.name) existing = self._tools.get(tool.name)
if existing: if existing:
if self.warn_on_duplicate_tools: if self.warn_on_duplicate_tools:

View File

@@ -102,9 +102,7 @@ class FuncMetadata(BaseModel):
) )
def func_metadata( def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata:
func: Callable[..., Any], skip_names: Sequence[str] = ()
) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its """Given a function, return metadata including a pydantic model representing its
signature. signature.
@@ -131,9 +129,7 @@ def func_metadata(
globalns = getattr(func, "__globals__", {}) globalns = getattr(func, "__globals__", {})
for param in params.values(): for param in params.values():
if param.name.startswith("_"): if param.name.startswith("_"):
raise InvalidSignature( raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
f"Parameter {param.name} of {func.__name__} cannot start with '_'"
)
if param.name in skip_names: if param.name in skip_names:
continue continue
annotation = param.annotation annotation = param.annotation
@@ -142,11 +138,7 @@ def func_metadata(
if annotation is None: if annotation is None:
annotation = Annotated[ annotation = Annotated[
None, None,
Field( Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
default=param.default
if param.default is not inspect.Parameter.empty
else PydanticUndefined
),
] ]
# Untyped field # Untyped field
@@ -160,9 +152,7 @@ def func_metadata(
field_info = FieldInfo.from_annotated_attribute( field_info = FieldInfo.from_annotated_attribute(
_get_typed_annotation(annotation, globalns), _get_typed_annotation(annotation, globalns),
param.default param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
if param.default is not inspect.Parameter.empty
else PydanticUndefined,
) )
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
continue continue
@@ -177,9 +167,7 @@ def func_metadata(
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
def try_eval_type( def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]:
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
) -> tuple[Any, bool]:
try: try:
return eval_type_backport(value, globalns, localns), True return eval_type_backport(value, globalns, localns), True
except NameError: except NameError:

View File

@@ -95,9 +95,7 @@ LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any) RequestT = TypeVar("RequestT", default=Any)
# This will be properly typed in each Server instance's context # This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
contextvars.ContextVar("request_ctx")
)
class NotificationOptions: class NotificationOptions:
@@ -140,9 +138,7 @@ class Server(Generic[LifespanResultT, RequestT]):
self.version = version self.version = version
self.instructions = instructions self.instructions = instructions
self.lifespan = lifespan self.lifespan = lifespan
self.request_handlers: dict[ self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
type, Callable[..., Awaitable[types.ServerResult]]
] = {
types.PingRequest: _ping_handler, types.PingRequest: _ping_handler,
} }
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
@@ -189,9 +185,7 @@ class Server(Generic[LifespanResultT, RequestT]):
# Set prompt capabilities if handler exists # Set prompt capabilities if handler exists
if types.ListPromptsRequest in self.request_handlers: if types.ListPromptsRequest in self.request_handlers:
prompts_capability = types.PromptsCapability( prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed)
listChanged=notification_options.prompts_changed
)
# Set resource capabilities if handler exists # Set resource capabilities if handler exists
if types.ListResourcesRequest in self.request_handlers: if types.ListResourcesRequest in self.request_handlers:
@@ -201,9 +195,7 @@ class Server(Generic[LifespanResultT, RequestT]):
# Set tool capabilities if handler exists # Set tool capabilities if handler exists
if types.ListToolsRequest in self.request_handlers: if types.ListToolsRequest in self.request_handlers:
tools_capability = types.ToolsCapability( tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed)
listChanged=notification_options.tools_changed
)
# Set logging capabilities if handler exists # Set logging capabilities if handler exists
if types.SetLevelRequest in self.request_handlers: if types.SetLevelRequest in self.request_handlers:
@@ -239,9 +231,7 @@ class Server(Generic[LifespanResultT, RequestT]):
def get_prompt(self): def get_prompt(self):
def decorator( def decorator(
func: Callable[ func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
],
): ):
logger.debug("Registering handler for GetPromptRequest") logger.debug("Registering handler for GetPromptRequest")
@@ -260,9 +250,7 @@ class Server(Generic[LifespanResultT, RequestT]):
async def handler(_: Any): async def handler(_: Any):
resources = await func() resources = await func()
return types.ServerResult( return types.ServerResult(types.ListResourcesResult(resources=resources))
types.ListResourcesResult(resources=resources)
)
self.request_handlers[types.ListResourcesRequest] = handler self.request_handlers[types.ListResourcesRequest] = handler
return func return func
@@ -275,9 +263,7 @@ class Server(Generic[LifespanResultT, RequestT]):
async def handler(_: Any): async def handler(_: Any):
templates = await func() templates = await func()
return types.ServerResult( return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))
types.ListResourceTemplatesResult(resourceTemplates=templates)
)
self.request_handlers[types.ListResourceTemplatesRequest] = handler self.request_handlers[types.ListResourceTemplatesRequest] = handler
return func return func
@@ -286,9 +272,7 @@ class Server(Generic[LifespanResultT, RequestT]):
def read_resource(self): def read_resource(self):
def decorator( def decorator(
func: Callable[ func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]],
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
],
): ):
logger.debug("Registering handler for ReadResourceRequest") logger.debug("Registering handler for ReadResourceRequest")
@@ -323,8 +307,7 @@ class Server(Generic[LifespanResultT, RequestT]):
content = create_content(data, None) content = create_content(data, None)
case Iterable() as contents: case Iterable() as contents:
contents_list = [ contents_list = [
create_content(content_item.content, content_item.mime_type) create_content(content_item.content, content_item.mime_type) for content_item in contents
for content_item in contents
] ]
return types.ServerResult( return types.ServerResult(
types.ReadResourceResult( types.ReadResourceResult(
@@ -332,9 +315,7 @@ class Server(Generic[LifespanResultT, RequestT]):
) )
) )
case _: case _:
raise ValueError( raise ValueError(f"Unexpected return type from read_resource: {type(result)}")
f"Unexpected return type from read_resource: {type(result)}"
)
return types.ServerResult( return types.ServerResult(
types.ReadResourceResult( types.ReadResourceResult(
@@ -404,12 +385,7 @@ class Server(Generic[LifespanResultT, RequestT]):
func: Callable[ func: Callable[
..., ...,
Awaitable[ Awaitable[
Iterable[ Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
types.TextContent
| types.ImageContent
| types.AudioContent
| types.EmbeddedResource
]
], ],
], ],
): ):
@@ -418,9 +394,7 @@ class Server(Generic[LifespanResultT, RequestT]):
async def handler(req: types.CallToolRequest): async def handler(req: types.CallToolRequest):
try: try:
results = await func(req.params.name, (req.params.arguments or {})) results = await func(req.params.name, (req.params.arguments or {}))
return types.ServerResult( return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
types.CallToolResult(content=list(results), isError=False)
)
except Exception as e: except Exception as e:
return types.ServerResult( return types.ServerResult(
types.CallToolResult( types.CallToolResult(
@@ -436,9 +410,7 @@ class Server(Generic[LifespanResultT, RequestT]):
def progress_notification(self): def progress_notification(self):
def decorator( def decorator(
func: Callable[ func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
[str | int, float, float | None, str | None], Awaitable[None]
],
): ):
logger.debug("Registering handler for ProgressNotification") logger.debug("Registering handler for ProgressNotification")
@@ -525,9 +497,7 @@ class Server(Generic[LifespanResultT, RequestT]):
async def _handle_message( async def _handle_message(
self, self,
message: RequestResponder[types.ClientRequest, types.ServerResult] message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
| types.ClientNotification
| Exception,
session: ServerSession, session: ServerSession,
lifespan_context: LifespanResultT, lifespan_context: LifespanResultT,
raise_exceptions: bool = False, raise_exceptions: bool = False,
@@ -535,20 +505,14 @@ class Server(Generic[LifespanResultT, RequestT]):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here. # TODO(Marcelo): We should be checking if message is Exception here.
match message: # type: ignore[reportMatchNotExhaustive] match message: # type: ignore[reportMatchNotExhaustive]
case ( case RequestResponder(request=types.ClientRequest(root=req)) as responder:
RequestResponder(request=types.ClientRequest(root=req)) as responder
):
with responder: with responder:
await self._handle_request( await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
message, req, session, lifespan_context, raise_exceptions
)
case types.ClientNotification(root=notify): case types.ClientNotification(root=notify):
await self._handle_notification(notify) await self._handle_notification(notify)
for warning in w: for warning in w:
logger.info( logger.info("Warning: %s: %s", warning.category.__name__, warning.message)
"Warning: %s: %s", warning.category.__name__, warning.message
)
async def _handle_request( async def _handle_request(
self, self,
@@ -566,9 +530,7 @@ class Server(Generic[LifespanResultT, RequestT]):
try: try:
# Extract request context from message metadata # Extract request context from message metadata
request_data = None request_data = None
if message.message_metadata is not None and isinstance( if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
message.message_metadata, ServerMessageMetadata
):
request_data = message.message_metadata.request_context request_data = message.message_metadata.request_context
# Set our global state that can be retrieved via # Set our global state that can be retrieved via

View File

@@ -64,9 +64,7 @@ class InitializationState(Enum):
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
ServerRequestResponder = ( ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult] RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
| types.ClientNotification
| Exception
) )
@@ -89,22 +87,16 @@ class ServerSession(
init_options: InitializationOptions, init_options: InitializationOptions,
stateless: bool = False, stateless: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
read_stream, write_stream, types.ClientRequest, types.ClientNotification
)
self._initialization_state = ( self._initialization_state = (
InitializationState.Initialized InitializationState.Initialized if stateless else InitializationState.NotInitialized
if stateless
else InitializationState.NotInitialized
) )
self._init_options = init_options self._init_options = init_options
self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
anyio.create_memory_object_stream[ServerRequestResponder](0) ServerRequestResponder
) ](0)
self._exit_stack.push_async_callback( self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
lambda: self._incoming_message_stream_reader.aclose()
)
@property @property
def client_params(self) -> types.InitializeRequestParams | None: def client_params(self) -> types.InitializeRequestParams | None:
@@ -134,10 +126,7 @@ class ServerSession(
return False return False
# Check each experimental capability # Check each experimental capability
for exp_key, exp_value in capability.experimental.items(): for exp_key, exp_value in capability.experimental.items():
if ( if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
exp_key not in client_caps.experimental
or client_caps.experimental[exp_key] != exp_value
):
return False return False
return True return True
@@ -146,9 +135,7 @@ class ServerSession(
async with self._incoming_message_stream_writer: async with self._incoming_message_stream_writer:
await super()._receive_loop() await super()._receive_loop()
async def _received_request( async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
):
match responder.request.root: match responder.request.root:
case types.InitializeRequest(params=params): case types.InitializeRequest(params=params):
requested_version = params.protocolVersion requested_version = params.protocolVersion
@@ -172,13 +159,9 @@ class ServerSession(
) )
case _: case _:
if self._initialization_state != InitializationState.Initialized: if self._initialization_state != InitializationState.Initialized:
raise RuntimeError( raise RuntimeError("Received request before initialization was complete")
"Received request before initialization was complete"
)
async def _received_notification( async def _received_notification(self, notification: types.ClientNotification) -> None:
self, notification: types.ClientNotification
) -> None:
# Need this to avoid ASYNC910 # Need this to avoid ASYNC910
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()
match notification.root: match notification.root:
@@ -186,9 +169,7 @@ class ServerSession(
self._initialization_state = InitializationState.Initialized self._initialization_state = InitializationState.Initialized
case _: case _:
if self._initialization_state != InitializationState.Initialized: if self._initialization_state != InitializationState.Initialized:
raise RuntimeError( raise RuntimeError("Received notification before initialization was complete")
"Received notification before initialization was complete"
)
async def send_log_message( async def send_log_message(
self, self,

View File

@@ -116,20 +116,14 @@ class SseServerTransport:
full_message_path_for_client = root_path.rstrip("/") + self._endpoint full_message_path_for_client = root_path.rstrip("/") + self._endpoint
# This is the URI (path + query) the client will use to POST messages. # This is the URI (path + query) the client will use to POST messages.
client_post_uri_data = ( client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
)
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0)
dict[str, Any]
](0)
async def sse_writer(): async def sse_writer():
logger.debug("Starting SSE writer") logger.debug("Starting SSE writer")
async with sse_stream_writer, write_stream_reader: async with sse_stream_writer, write_stream_reader:
await sse_stream_writer.send( await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data})
{"event": "endpoint", "data": client_post_uri_data}
)
logger.debug(f"Sent endpoint event: {client_post_uri_data}") logger.debug(f"Sent endpoint event: {client_post_uri_data}")
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
@@ -137,9 +131,7 @@ class SseServerTransport:
await sse_stream_writer.send( await sse_stream_writer.send(
{ {
"event": "message", "event": "message",
"data": session_message.message.model_dump_json( "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True),
by_alias=True, exclude_none=True
),
} }
) )
@@ -151,9 +143,9 @@ class SseServerTransport:
In this case we close our side of the streams to signal the client that In this case we close our side of the streams to signal the client that
the connection has been closed. the connection has been closed.
""" """
await EventSourceResponse( await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
content=sse_stream_reader, data_sender_callable=sse_writer scope, receive, send
)(scope, receive, send) )
await read_stream_writer.aclose() await read_stream_writer.aclose()
await write_stream_reader.aclose() await write_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}") logging.debug(f"Client session disconnected {session_id}")
@@ -164,9 +156,7 @@ class SseServerTransport:
logger.debug("Yielding read and write streams") logger.debug("Yielding read and write streams")
yield (read_stream, write_stream) yield (read_stream, write_stream)
async def handle_post_message( async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
self, scope: Scope, receive: Receive, send: Send
) -> None:
logger.debug("Handling POST message") logger.debug("Handling POST message")
request = Request(scope, receive) request = Request(scope, receive)

View File

@@ -76,9 +76,7 @@ async def stdio_server(
try: try:
async with write_stream_reader: async with write_stream_reader:
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
json = session_message.message.model_dump_json( json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
by_alias=True, exclude_none=True
)
await stdout.write(json + "\n") await stdout.write(json + "\n")
await stdout.flush() await stdout.flush()
except anyio.ClosedResourceError: except anyio.ClosedResourceError:

View File

@@ -82,9 +82,7 @@ class EventStore(ABC):
""" """
@abstractmethod @abstractmethod
async def store_event( async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
self, stream_id: StreamId, message: JSONRPCMessage
) -> EventId:
""" """
Stores an event for later retrieval. Stores an event for later retrieval.
@@ -125,9 +123,7 @@ class StreamableHTTPServerTransport:
""" """
# Server notification streams for POST requests as well as standalone SSE stream # Server notification streams for POST requests as well as standalone SSE stream
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None
None
)
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream: MemoryObjectSendStream[SessionMessage] | None = None
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
@@ -153,12 +149,8 @@ class StreamableHTTPServerTransport:
Raises: Raises:
ValueError: If the session ID contains invalid characters. ValueError: If the session ID contains invalid characters.
""" """
if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id):
mcp_session_id raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)")
):
raise ValueError(
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
)
self.mcp_session_id = mcp_session_id self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled self.is_json_response_enabled = is_json_response_enabled
@@ -218,9 +210,7 @@ class StreamableHTTPServerTransport:
response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
return Response( return Response(
response_message.model_dump_json(by_alias=True, exclude_none=True) response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None,
if response_message
else None,
status_code=status_code, status_code=status_code,
headers=response_headers, headers=response_headers,
) )
@@ -233,9 +223,7 @@ class StreamableHTTPServerTransport:
"""Create event data dictionary from an EventMessage.""" """Create event data dictionary from an EventMessage."""
event_data = { event_data = {
"event": "message", "event": "message",
"data": event_message.message.model_dump_json( "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
by_alias=True, exclude_none=True
),
} }
# If an event ID was provided, include it # If an event ID was provided, include it
@@ -283,42 +271,29 @@ class StreamableHTTPServerTransport:
accept_header = request.headers.get("accept", "") accept_header = request.headers.get("accept", "")
accept_types = [media_type.strip() for media_type in accept_header.split(",")] accept_types = [media_type.strip() for media_type in accept_header.split(",")]
has_json = any( has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)
)
has_sse = any(
media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
)
return has_json, has_sse return has_json, has_sse
def _check_content_type(self, request: Request) -> bool: def _check_content_type(self, request: Request) -> bool:
"""Check if the request has the correct Content-Type.""" """Check if the request has the correct Content-Type."""
content_type = request.headers.get("content-type", "") content_type = request.headers.get("content-type", "")
content_type_parts = [ content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")]
part.strip() for part in content_type.split(";")[0].split(",")
]
return any(part == CONTENT_TYPE_JSON for part in content_type_parts) return any(part == CONTENT_TYPE_JSON for part in content_type_parts)
async def _handle_post_request( async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
self, scope: Scope, request: Request, receive: Receive, send: Send
) -> None:
"""Handle POST requests containing JSON-RPC messages.""" """Handle POST requests containing JSON-RPC messages."""
writer = self._read_stream_writer writer = self._read_stream_writer
if writer is None: if writer is None:
raise ValueError( raise ValueError("No read stream writer available. Ensure connect() is called first.")
"No read stream writer available. Ensure connect() is called first."
)
try: try:
# Check Accept headers # Check Accept headers
has_json, has_sse = self._check_accept_headers(request) has_json, has_sse = self._check_accept_headers(request)
if not (has_json and has_sse): if not (has_json and has_sse):
response = self._create_error_response( response = self._create_error_response(
( ("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
"Not Acceptable: Client must accept both application/json and "
"text/event-stream"
),
HTTPStatus.NOT_ACCEPTABLE, HTTPStatus.NOT_ACCEPTABLE,
) )
await response(scope, receive, send) await response(scope, receive, send)
@@ -346,9 +321,7 @@ class StreamableHTTPServerTransport:
try: try:
raw_message = json.loads(body) raw_message = json.loads(body)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
response = self._create_error_response( response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
)
await response(scope, receive, send) await response(scope, receive, send)
return return
@@ -364,10 +337,7 @@ class StreamableHTTPServerTransport:
return return
# Check if this is an initialization request # Check if this is an initialization request
is_initialization_request = ( is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
isinstance(message.root, JSONRPCRequest)
and message.root.method == "initialize"
)
if is_initialization_request: if is_initialization_request:
# Check if the server already has an established session # Check if the server already has an established session
@@ -406,9 +376,7 @@ class StreamableHTTPServerTransport:
# Extract the request ID outside the try block for proper scope # Extract the request ID outside the try block for proper scope
request_id = str(message.root.id) request_id = str(message.root.id)
# Register this stream for the request ID # Register this stream for the request ID
self._request_streams[request_id] = anyio.create_memory_object_stream[ self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
EventMessage
](0)
request_stream_reader = self._request_streams[request_id][1] request_stream_reader = self._request_streams[request_id][1]
if self.is_json_response_enabled: if self.is_json_response_enabled:
@@ -424,16 +392,12 @@ class StreamableHTTPServerTransport:
# Use similar approach to SSE writer for consistency # Use similar approach to SSE writer for consistency
async for event_message in request_stream_reader: async for event_message in request_stream_reader:
# If it's a response, this is what we're waiting for # If it's a response, this is what we're waiting for
if isinstance( if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
event_message.message.root, JSONRPCResponse | JSONRPCError
):
response_message = event_message.message response_message = event_message.message
break break
# For notifications and request, keep waiting # For notifications and request, keep waiting
else: else:
logger.debug( logger.debug(f"received: {event_message.message.root.method}")
f"received: {event_message.message.root.method}"
)
# At this point we should have a response # At this point we should have a response
if response_message: if response_message:
@@ -442,9 +406,7 @@ class StreamableHTTPServerTransport:
await response(scope, receive, send) await response(scope, receive, send)
else: else:
# This shouldn't happen in normal operation # This shouldn't happen in normal operation
logger.error( logger.error("No response message received before stream closed")
"No response message received before stream closed"
)
response = self._create_error_response( response = self._create_error_response(
"Error processing request: No response received", "Error processing request: No response received",
HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR,
@@ -462,9 +424,7 @@ class StreamableHTTPServerTransport:
await self._clean_up_memory_streams(request_id) await self._clean_up_memory_streams(request_id)
else: else:
# Create SSE stream # Create SSE stream
sse_stream_writer, sse_stream_reader = ( sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
anyio.create_memory_object_stream[dict[str, str]](0)
)
async def sse_writer(): async def sse_writer():
# Get the request ID from the incoming request message # Get the request ID from the incoming request message
@@ -495,11 +455,7 @@ class StreamableHTTPServerTransport:
"Cache-Control": "no-cache, no-transform", "Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive", "Connection": "keep-alive",
"Content-Type": CONTENT_TYPE_SSE, "Content-Type": CONTENT_TYPE_SSE,
**( **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
{MCP_SESSION_ID_HEADER: self.mcp_session_id}
if self.mcp_session_id
else {}
),
} }
response = EventSourceResponse( response = EventSourceResponse(
content=sse_stream_reader, content=sse_stream_reader,
@@ -544,9 +500,7 @@ class StreamableHTTPServerTransport:
""" """
writer = self._read_stream_writer writer = self._read_stream_writer
if writer is None: if writer is None:
raise ValueError( raise ValueError("No read stream writer available. Ensure connect() is called first.")
"No read stream writer available. Ensure connect() is called first."
)
# Validate Accept header - must include text/event-stream # Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request) _, has_sse = self._check_accept_headers(request)
@@ -585,17 +539,13 @@ class StreamableHTTPServerTransport:
return return
# Create SSE stream # Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
dict[str, str]
](0)
async def standalone_sse_writer(): async def standalone_sse_writer():
try: try:
# Create a standalone message stream for server-initiated messages # Create a standalone message stream for server-initiated messages
self._request_streams[GET_STREAM_KEY] = ( self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
async with sse_stream_writer, standalone_stream_reader: async with sse_stream_writer, standalone_stream_reader:
@@ -732,9 +682,7 @@ class StreamableHTTPServerTransport:
return True return True
async def _replay_events( async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
self, last_event_id: str, request: Request, send: Send
) -> None:
""" """
Replays events that would have been sent after the specified event ID. Replays events that would have been sent after the specified event ID.
Only used when resumability is enabled. Only used when resumability is enabled.
@@ -754,9 +702,7 @@ class StreamableHTTPServerTransport:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
# Create SSE stream for replay # Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
dict[str, str]
](0)
async def replay_sender(): async def replay_sender():
try: try:
@@ -767,15 +713,11 @@ class StreamableHTTPServerTransport:
await sse_stream_writer.send(event_data) await sse_stream_writer.send(event_data)
# Replay past events and get the stream ID # Replay past events and get the stream ID
stream_id = await event_store.replay_events_after( stream_id = await event_store.replay_events_after(last_event_id, send_event)
last_event_id, send_event
)
# If stream ID not in mapping, create it # If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams: if stream_id and stream_id not in self._request_streams:
self._request_streams[stream_id] = ( self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
anyio.create_memory_object_stream[EventMessage](0)
)
msg_reader = self._request_streams[stream_id][1] msg_reader = self._request_streams[stream_id][1]
# Forward messages to SSE # Forward messages to SSE
@@ -829,12 +771,8 @@ class StreamableHTTPServerTransport:
# Create the memory streams for this connection # Create the memory streams for this connection
read_stream_writer, read_stream = anyio.create_memory_object_stream[ read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
SessionMessage | Exception write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[
SessionMessage
](0)
# Store the streams # Store the streams
self._read_stream_writer = read_stream_writer self._read_stream_writer = read_stream_writer
@@ -867,35 +805,24 @@ class StreamableHTTPServerTransport:
session_message.metadata, session_message.metadata,
ServerMessageMetadata, ServerMessageMetadata,
) )
and session_message.metadata.related_request_id and session_message.metadata.related_request_id is not None
is not None
): ):
target_request_id = str( target_request_id = str(session_message.metadata.related_request_id)
session_message.metadata.related_request_id
)
request_stream_id = ( request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY
target_request_id
if target_request_id is not None
else GET_STREAM_KEY
)
# Store the event if we have an event store, # Store the event if we have an event store,
# regardless of whether a client is connected # regardless of whether a client is connected
# messages will be replayed on the re-connect # messages will be replayed on the re-connect
event_id = None event_id = None
if self._event_store: if self._event_store:
event_id = await self._event_store.store_event( event_id = await self._event_store.store_event(request_stream_id, message)
request_stream_id, message
)
logger.debug(f"Stored {event_id} from {request_stream_id}") logger.debug(f"Stored {event_id} from {request_stream_id}")
if request_stream_id in self._request_streams: if request_stream_id in self._request_streams:
try: try:
# Send both the message and the event ID # Send both the message and the event ID
await self._request_streams[request_stream_id][0].send( await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
EventMessage(message, event_id)
)
except ( except (
anyio.BrokenResourceError, anyio.BrokenResourceError,
anyio.ClosedResourceError, anyio.ClosedResourceError,

View File

@@ -165,9 +165,7 @@ class StreamableHTTPSessionManager:
) )
# Start server in a new task # Start server in a new task
async def run_stateless_server( async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
):
async with http_transport.connect() as streams: async with http_transport.connect() as streams:
read_stream, write_stream = streams read_stream, write_stream = streams
task_status.started() task_status.started()
@@ -204,10 +202,7 @@ class StreamableHTTPSessionManager:
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
# Existing session case # Existing session case
if ( if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
request_mcp_session_id is not None
and request_mcp_session_id in self._server_instances
):
transport = self._server_instances[request_mcp_session_id] transport = self._server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly") logger.debug("Session already exists, handling request directly")
await transport.handle_request(scope, receive, send) await transport.handle_request(scope, receive, send)
@@ -229,9 +224,7 @@ class StreamableHTTPSessionManager:
logger.info(f"Created new transport with session ID: {new_session_id}") logger.info(f"Created new transport with session ID: {new_session_id}")
# Define the server runner # Define the server runner
async def run_server( async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
) -> None:
async with http_transport.connect() as streams: async with http_transport.connect() as streams:
read_stream, write_stream = streams read_stream, write_stream = streams
task_status.started() task_status.started()

View File

@@ -93,12 +93,8 @@ class StreamingASGITransport(AsyncBaseTransport):
initial_response_ready = anyio.Event() initial_response_ready = anyio.Event()
# Synchronization for streaming response # Synchronization for streaming response
asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[ asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100)
dict[str, Any] content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100)
](100)
content_send_channel, content_receive_channel = (
anyio.create_memory_object_stream[bytes](100)
)
# ASGI callables. # ASGI callables.
async def receive() -> dict[str, Any]: async def receive() -> dict[str, Any]:
@@ -124,21 +120,15 @@ class StreamingASGITransport(AsyncBaseTransport):
async def run_app() -> None: async def run_app() -> None:
try: try:
# Cast the receive and send functions to the ASGI types # Cast the receive and send functions to the ASGI types
await self.app( await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send))
cast(Scope, scope), cast(Receive, receive), cast(Send, send)
)
except Exception: except Exception:
if self.raise_app_exceptions: if self.raise_app_exceptions:
raise raise
if not response_started: if not response_started:
await asgi_send_channel.send( await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []})
{"type": "http.response.start", "status": 500, "headers": []}
)
await asgi_send_channel.send( await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False})
{"type": "http.response.body", "body": b"", "more_body": False}
)
finally: finally:
await asgi_send_channel.aclose() await asgi_send_channel.aclose()

View File

@@ -51,9 +51,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
try: try:
async with write_stream_reader: async with write_stream_reader:
async for session_message in write_stream_reader: async for session_message in write_stream_reader:
obj = session_message.message.model_dump_json( obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
by_alias=True, exclude_none=True
)
await websocket.send_text(obj) await websocket.send_text(obj)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await websocket.close() await websocket.close()

View File

@@ -45,9 +45,7 @@ class OAuthClientMetadata(BaseModel):
# token_endpoint_auth_method: this implementation only supports none & # token_endpoint_auth_method: this implementation only supports none &
# client_secret_post; # client_secret_post;
# ie: we do not support client_secret_basic # ie: we do not support client_secret_basic
token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
"client_secret_post"
)
# grant_types: this implementation only supports authorization_code & refresh_token # grant_types: this implementation only supports authorization_code & refresh_token
grant_types: list[Literal["authorization_code", "refresh_token"]] = [ grant_types: list[Literal["authorization_code", "refresh_token"]] = [
"authorization_code", "authorization_code",
@@ -84,17 +82,12 @@ class OAuthClientMetadata(BaseModel):
if redirect_uri is not None: if redirect_uri is not None:
# Validate redirect_uri against client's registered redirect URIs # Validate redirect_uri against client's registered redirect URIs
if redirect_uri not in self.redirect_uris: if redirect_uri not in self.redirect_uris:
raise InvalidRedirectUriError( raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
f"Redirect URI '{redirect_uri}' not registered for client"
)
return redirect_uri return redirect_uri
elif len(self.redirect_uris) == 1: elif len(self.redirect_uris) == 1:
return self.redirect_uris[0] return self.redirect_uris[0]
else: else:
raise InvalidRedirectUriError( raise InvalidRedirectUriError("redirect_uri must be specified when client " "has multiple registered URIs")
"redirect_uri must be specified when client "
"has multiple registered URIs"
)
class OAuthClientInformationFull(OAuthClientMetadata): class OAuthClientInformationFull(OAuthClientMetadata):

View File

@@ -11,26 +11,15 @@ import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types import mcp.types as types
from mcp.client.session import ( from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
ClientSession,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server from mcp.server import Server
from mcp.shared.message import SessionMessage from mcp.shared.message import SessionMessage
MessageStream = tuple[ MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
]
@asynccontextmanager @asynccontextmanager
async def create_client_server_memory_streams() -> ( async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]:
AsyncGenerator[tuple[MessageStream, MessageStream], None]
):
""" """
Creates a pair of bidirectional memory streams for client-server communication. Creates a pair of bidirectional memory streams for client-server communication.
@@ -39,12 +28,8 @@ async def create_client_server_memory_streams() -> (
(read_stream, write_stream) (read_stream, write_stream)
""" """
# Create streams for both directions # Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
SessionMessage | Exception client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage | Exception
](1)
client_streams = (server_to_client_receive, client_to_server_send) client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send) server_streams = (client_to_server_receive, server_to_client_send)

View File

@@ -20,9 +20,7 @@ class ClientMessageMetadata:
"""Metadata specific to client messages.""" """Metadata specific to client messages."""
resumption_token: ResumptionToken | None = None resumption_token: ResumptionToken | None = None
on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None
None
)
@dataclass @dataclass

View File

@@ -23,22 +23,8 @@ class Progress(BaseModel):
@dataclass @dataclass
class ProgressContext( class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
Generic[ session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
):
session: BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
]
progress_token: ProgressToken progress_token: ProgressToken
total: float | None total: float | None
current: float = field(default=0.0, init=False) current: float = field(default=0.0, init=False)
@@ -54,24 +40,12 @@ class ProgressContext(
@contextmanager @contextmanager
def progress( def progress(
ctx: RequestContext[ ctx: RequestContext[
BaseSession[ BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
LifespanContextT, LifespanContextT,
], ],
total: float | None = None, total: float | None = None,
) -> Generator[ ) -> Generator[
ProgressContext[ ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
None, None,
]: ]:
if ctx.meta is None or ctx.meta.progressToken is None: if ctx.meta is None or ctx.meta.progressToken is None:

View File

@@ -38,9 +38,7 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar( ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
"ReceiveNotificationT", ClientNotification, ServerNotification
)
RequestId = str | int RequestId = str | int
@@ -48,9 +46,7 @@ RequestId = str | int
class ProgressFnT(Protocol): class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks.""" """Protocol for progress notification callbacks."""
async def __call__( async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...
self, progress: float, total: float | None, message: str | None
) -> None: ...
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@@ -177,9 +173,7 @@ class BaseSession(
messages when entered. messages when entered.
""" """
_response_streams: dict[ _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int _request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT] _progress_callbacks: dict[RequestId, ProgressFnT]
@@ -242,9 +236,7 @@ class BaseSession(
request_id = self._request_id request_id = self._request_id
self._request_id = request_id + 1 self._request_id = request_id + 1
response_stream, response_stream_reader = anyio.create_memory_object_stream[ response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
JSONRPCResponse | JSONRPCError
](1)
self._response_streams[request_id] = response_stream self._response_streams[request_id] = response_stream
# Set up progress token if progress callback is provided # Set up progress token if progress callback is provided
@@ -266,11 +258,7 @@ class BaseSession(
**request_data, **request_data,
) )
await self._write_stream.send( await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
)
)
# request read timeout takes precedence over session read timeout # request read timeout takes precedence over session read timeout
timeout = None timeout = None
@@ -322,15 +310,11 @@ class BaseSession(
) )
session_message = SessionMessage( session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification), message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id) metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
if related_request_id
else None,
) )
await self._write_stream.send(session_message) await self._write_stream.send(session_message)
async def _send_response( async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
self, request_id: RequestId, response: SendResultT | ErrorData
) -> None:
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -339,9 +323,7 @@ class BaseSession(
jsonrpc_response = JSONRPCResponse( jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=request_id, id=request_id,
result=response.model_dump( result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(session_message) await self._write_stream.send(session_message)
@@ -357,19 +339,14 @@ class BaseSession(
elif isinstance(message.message.root, JSONRPCRequest): elif isinstance(message.message.root, JSONRPCRequest):
try: try:
validated_request = self._receive_request_type.model_validate( validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
by_alias=True, mode="json", exclude_none=True
)
) )
responder = RequestResponder( responder = RequestResponder(
request_id=message.message.root.id, request_id=message.message.root.id,
request_meta=validated_request.root.params.meta request_meta=validated_request.root.params.meta if validated_request.root.params else None,
if validated_request.root.params
else None,
request=validated_request, request=validated_request,
session=self, session=self,
on_complete=lambda r: self._in_flight.pop( on_complete=lambda r: self._in_flight.pop(r.request_id, None),
r.request_id, None),
message_metadata=message.metadata, message_metadata=message.metadata,
) )
self._in_flight[responder.request_id] = responder self._in_flight[responder.request_id] = responder
@@ -381,9 +358,7 @@ class BaseSession(
# For request validation errors, send a proper JSON-RPC error # For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server # response instead of crashing the server
logging.warning(f"Failed to validate request: {e}") logging.warning(f"Failed to validate request: {e}")
logging.debug( logging.debug(f"Message that failed validation: {message.message.root}")
f"Message that failed validation: {message.message.root}"
)
error_response = JSONRPCError( error_response = JSONRPCError(
jsonrpc="2.0", jsonrpc="2.0",
id=message.message.root.id, id=message.message.root.id,
@@ -393,16 +368,13 @@ class BaseSession(
data="", data="",
), ),
) )
session_message = SessionMessage( session_message = SessionMessage(message=JSONRPCMessage(error_response))
message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message) await self._write_stream.send(session_message)
elif isinstance(message.message.root, JSONRPCNotification): elif isinstance(message.message.root, JSONRPCNotification):
try: try:
notification = self._receive_notification_type.model_validate( notification = self._receive_notification_type.model_validate(
message.message.root.model_dump( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
by_alias=True, mode="json", exclude_none=True
)
) )
# Handle cancellation notifications # Handle cancellation notifications
if isinstance(notification.root, CancelledNotification): if isinstance(notification.root, CancelledNotification):
@@ -427,8 +399,7 @@ class BaseSession(
except Exception as e: except Exception as e:
# For other validation errors, log and continue # For other validation errors, log and continue
logging.warning( logging.warning(
f"Failed to validate notification: {e}. " f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
f"Message was: {message.message.root}"
) )
else: # Response or error else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None) stream = self._response_streams.pop(message.message.root.id, None)
@@ -436,10 +407,7 @@ class BaseSession(
await stream.send(message.message.root) await stream.send(message.message.root)
else: else:
await self._handle_incoming( await self._handle_incoming(
RuntimeError( RuntimeError("Received response with an unknown " f"request ID: {message}")
"Received response with an unknown "
f"request ID: {message}"
)
) )
# after the read stream is closed, we need to send errors # after the read stream is closed, we need to send errors
@@ -450,9 +418,7 @@ class BaseSession(
await stream.aclose() await stream.aclose()
self._response_streams.clear() self._response_streams.clear()
async def _received_request( async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None:
""" """
Can be overridden by subclasses to handle a request without needing to Can be overridden by subclasses to handle a request without needing to
listen on the message stream. listen on the message stream.
@@ -481,9 +447,7 @@ class BaseSession(
async def _handle_incoming( async def _handle_incoming(
self, self,
req: RequestResponder[ReceiveRequestT, SendResultT] req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
| ReceiveNotificationT
| Exception,
) -> None: ) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses.""" """A generic handler for incoming messages. Overwritten by subclasses."""
pass pass

View File

@@ -1,12 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from typing import ( from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
Annotated,
Any,
Generic,
Literal,
TypeAlias,
TypeVar,
)
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints from pydantic.networks import AnyUrl, UrlConstraints
@@ -73,9 +66,7 @@ class NotificationParams(BaseModel):
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar( NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
)
MethodT = TypeVar("MethodT", bound=str) MethodT = TypeVar("MethodT", bound=str)
@@ -87,9 +78,7 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class PaginatedRequest( class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
):
"""Base class for paginated requests, """Base class for paginated requests,
matching the schema's PaginatedRequest interface.""" matching the schema's PaginatedRequest interface."""
@@ -191,9 +180,7 @@ class JSONRPCError(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class JSONRPCMessage( class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]
):
pass pass
@@ -314,9 +301,7 @@ class InitializeResult(Result):
"""Instructions describing how to use the server and its features.""" """Instructions describing how to use the server and its features."""
class InitializedNotification( class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]):
Notification[NotificationParams | None, Literal["notifications/initialized"]]
):
""" """
This notification is sent from the client to the server after initialization has This notification is sent from the client to the server after initialization has
finished. finished.
@@ -359,9 +344,7 @@ class ProgressNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ProgressNotification( class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]):
Notification[ProgressNotificationParams, Literal["notifications/progress"]]
):
""" """
An out-of-band notification used to inform the receiver of a progress update for a An out-of-band notification used to inform the receiver of a progress update for a
long-running request. long-running request.
@@ -432,9 +415,7 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource] resources: list[Resource]
class ListResourceTemplatesRequest( class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
PaginatedRequest[Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has.""" """Sent from the client to request a list of resource templates the server has."""
method: Literal["resources/templates/list"] method: Literal["resources/templates/list"]
@@ -457,9 +438,7 @@ class ReadResourceRequestParams(RequestParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class ReadResourceRequest( class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
Request[ReadResourceRequestParams, Literal["resources/read"]]
):
"""Sent from the client to the server, to read a specific resource URI.""" """Sent from the client to the server, to read a specific resource URI."""
method: Literal["resources/read"] method: Literal["resources/read"]
@@ -500,9 +479,7 @@ class ReadResourceResult(Result):
class ResourceListChangedNotification( class ResourceListChangedNotification(
Notification[ Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]]
NotificationParams | None, Literal["notifications/resources/list_changed"]
]
): ):
""" """
An optional notification from the server to the client, informing it that the list An optional notification from the server to the client, informing it that the list
@@ -542,9 +519,7 @@ class UnsubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class UnsubscribeRequest( class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]):
Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
):
""" """
Sent from the client to request cancellation of resources/updated notifications from Sent from the client to request cancellation of resources/updated notifications from
the server. the server.
@@ -566,9 +541,7 @@ class ResourceUpdatedNotificationParams(NotificationParams):
class ResourceUpdatedNotification( class ResourceUpdatedNotification(
Notification[ Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]]
ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
]
): ):
""" """
A notification from the server to the client, informing it that a resource has A notification from the server to the client, informing it that a resource has
@@ -711,9 +684,7 @@ class GetPromptResult(Result):
class PromptListChangedNotification( class PromptListChangedNotification(
Notification[ Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]]
NotificationParams | None, Literal["notifications/prompts/list_changed"]
]
): ):
""" """
An optional notification from the server to the client, informing it that the list An optional notification from the server to the client, informing it that the list
@@ -820,9 +791,7 @@ class CallToolResult(Result):
isError: bool = False isError: bool = False
class ToolListChangedNotification( class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]):
Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
):
""" """
An optional notification from the server to the client, informing it that the list An optional notification from the server to the client, informing it that the list
of tools it offers has changed. of tools it offers has changed.
@@ -832,9 +801,7 @@ class ToolListChangedNotification(
params: NotificationParams | None = None params: NotificationParams | None = None
LoggingLevel = Literal[ LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"]
"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
]
class SetLevelRequestParams(RequestParams): class SetLevelRequestParams(RequestParams):
@@ -867,9 +834,7 @@ class LoggingMessageNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class LoggingMessageNotification( class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
):
"""Notification of a log message passed from server to client.""" """Notification of a log message passed from server to client."""
method: Literal["notifications/message"] method: Literal["notifications/message"]
@@ -964,9 +929,7 @@ class CreateMessageRequestParams(RequestParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class CreateMessageRequest( class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
):
"""A request from the server to sample an LLM via the client.""" """A request from the server to sample an LLM via the client."""
method: Literal["sampling/createMessage"] method: Literal["sampling/createMessage"]
@@ -1123,9 +1086,7 @@ class CancelledNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
class CancelledNotification( class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]):
Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]
):
""" """
This notification can be sent by either side to indicate that it is canceling a This notification can be sent by either side to indicate that it is canceling a
previously-issued request. previously-issued request.
@@ -1156,12 +1117,7 @@ class ClientRequest(
class ClientNotification( class ClientNotification(
RootModel[ RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification]
CancelledNotification
| ProgressNotification
| InitializedNotification
| RootsListChangedNotification
]
): ):
pass pass

View File

@@ -49,8 +49,7 @@ class StreamSpyCollection:
return [ return [
req.message.root req.message.root
for req in self.client.sent_messages for req in self.client.sent_messages
if isinstance(req.message.root, JSONRPCRequest) if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
and (method is None or req.message.root.method == method)
] ]
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
@@ -58,13 +57,10 @@ class StreamSpyCollection:
return [ return [
req.message.root req.message.root
for req in self.server.sent_messages for req in self.server.sent_messages
if isinstance(req.message.root, JSONRPCRequest) if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
and (method is None or req.message.root.method == method)
] ]
def get_client_notifications( def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method.""" """Get client-sent notifications, optionally filtered by method."""
return [ return [
notif.message.root notif.message.root
@@ -73,9 +69,7 @@ class StreamSpyCollection:
and (method is None or notif.message.root.method == method) and (method is None or notif.message.root.method == method)
] ]
def get_server_notifications( def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
self, method: str | None = None
) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method.""" """Get server-sent notifications, optionally filtered by method."""
return [ return [
notif.message.root notif.message.root
@@ -133,9 +127,7 @@ def stream_spy():
yield (client_read, spy_client_write), (server_read, spy_server_write) yield (client_read, spy_client_write), (server_read, spy_server_write)
# Apply the patch for the duration of the test # Apply the patch for the duration of the test
with patch( with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
):
# Return a collection with helper methods # Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection: def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized" assert client_spy is not None, "client_spy was not initialized"

View File

@@ -134,9 +134,7 @@ class TestOAuthClientProvider:
assert len(verifier) == 128 assert len(verifier) == 128
# Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~") # Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~")
allowed_chars = set( allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)
assert set(verifier) <= allowed_chars assert set(verifier) <= allowed_chars
# Check uniqueness (generate multiple and ensure they're different) # Check uniqueness (generate multiple and ensure they're different)
@@ -151,9 +149,7 @@ class TestOAuthClientProvider:
# Manually calculate expected challenge # Manually calculate expected challenge
expected_digest = hashlib.sha256(verifier.encode()).digest() expected_digest = hashlib.sha256(verifier.encode()).digest()
expected_challenge = ( expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
)
assert challenge == expected_challenge assert challenge == expected_challenge
@@ -166,29 +162,19 @@ class TestOAuthClientProvider:
async def test_get_authorization_base_url(self, oauth_provider): async def test_get_authorization_base_url(self, oauth_provider):
"""Test authorization base URL extraction.""" """Test authorization base URL extraction."""
# Test with path # Test with path
assert ( assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"
oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
== "https://api.example.com"
)
# Test with no path # Test with no path
assert ( assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com"
oauth_provider._get_authorization_base_url("https://api.example.com")
== "https://api.example.com"
)
# Test with port # Test with port
assert ( assert (
oauth_provider._get_authorization_base_url( oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
"https://api.example.com:8080/path/to/mcp"
)
== "https://api.example.com:8080" == "https://api.example.com:8080"
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_discover_oauth_metadata_success( async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
self, oauth_provider, oauth_metadata
):
"""Test successful OAuth metadata discovery.""" """Test successful OAuth metadata discovery."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -201,23 +187,16 @@ class TestOAuthClientProvider:
mock_response.json.return_value = metadata_response mock_response.json.return_value = metadata_response
mock_client.get.return_value = mock_response mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata( result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
"https://api.example.com/v1/mcp"
)
assert result is not None assert result is not None
assert ( assert result.authorization_endpoint == oauth_metadata.authorization_endpoint
result.authorization_endpoint == oauth_metadata.authorization_endpoint
)
assert result.token_endpoint == oauth_metadata.token_endpoint assert result.token_endpoint == oauth_metadata.token_endpoint
# Verify correct URL was called # Verify correct URL was called
mock_client.get.assert_called_once() mock_client.get.assert_called_once()
call_args = mock_client.get.call_args[0] call_args = mock_client.get.call_args[0]
assert ( assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server"
call_args[0]
== "https://api.example.com/.well-known/oauth-authorization-server"
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_discover_oauth_metadata_not_found(self, oauth_provider): async def test_discover_oauth_metadata_not_found(self, oauth_provider):
@@ -230,16 +209,12 @@ class TestOAuthClientProvider:
mock_response.status_code = 404 mock_response.status_code = 404
mock_client.get.return_value = mock_response mock_client.get.return_value = mock_response
result = await oauth_provider._discover_oauth_metadata( result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
"https://api.example.com/v1/mcp"
)
assert result is None assert result is None
@pytest.mark.anyio @pytest.mark.anyio
async def test_discover_oauth_metadata_cors_fallback( async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata):
self, oauth_provider, oauth_metadata
):
"""Test OAuth metadata discovery with CORS fallback.""" """Test OAuth metadata discovery with CORS fallback."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
@@ -257,17 +232,13 @@ class TestOAuthClientProvider:
mock_response_success, # Second call succeeds mock_response_success, # Second call succeeds
] ]
result = await oauth_provider._discover_oauth_metadata( result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")
"https://api.example.com/v1/mcp"
)
assert result is not None assert result is not None
assert mock_client.get.call_count == 2 assert mock_client.get.call_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_register_oauth_client_success( async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info):
self, oauth_provider, oauth_metadata, oauth_client_info
):
"""Test successful OAuth client registration.""" """Test successful OAuth client registration."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -295,9 +266,7 @@ class TestOAuthClientProvider:
assert call_args[0][0] == str(oauth_metadata.registration_endpoint) assert call_args[0][0] == str(oauth_metadata.registration_endpoint)
@pytest.mark.anyio @pytest.mark.anyio
async def test_register_oauth_client_fallback_endpoint( async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test OAuth client registration with fallback endpoint.""" """Test OAuth client registration with fallback endpoint."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")
@@ -311,9 +280,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback) # Mock metadata discovery to return None (fallback)
with patch.object( with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
oauth_provider, "_discover_oauth_metadata", return_value=None
):
result = await oauth_provider._register_oauth_client( result = await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp", "https://api.example.com/v1/mcp",
oauth_provider.client_metadata, oauth_provider.client_metadata,
@@ -340,9 +307,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
# Mock metadata discovery to return None (fallback) # Mock metadata discovery to return None (fallback)
with patch.object( with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
oauth_provider, "_discover_oauth_metadata", return_value=None
):
with pytest.raises(httpx.HTTPStatusError): with pytest.raises(httpx.HTTPStatusError):
await oauth_provider._register_oauth_client( await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp", "https://api.example.com/v1/mcp",
@@ -406,9 +371,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token) await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio @pytest.mark.anyio
async def test_validate_token_scopes_unauthorized( async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata):
self, oauth_provider, client_metadata
):
"""Test scope validation with unauthorized scopes.""" """Test scope validation with unauthorized scopes."""
oauth_provider.client_metadata = client_metadata oauth_provider.client_metadata = client_metadata
token = OAuthToken( token = OAuthToken(
@@ -436,9 +399,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token) await oauth_provider._validate_token_scopes(token)
@pytest.mark.anyio @pytest.mark.anyio
async def test_initialize( async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info):
self, oauth_provider, mock_storage, oauth_token, oauth_client_info
):
"""Test initialization loading from storage.""" """Test initialization loading from storage."""
mock_storage._tokens = oauth_token mock_storage._tokens = oauth_token
mock_storage._client_info = oauth_client_info mock_storage._client_info = oauth_client_info
@@ -449,9 +410,7 @@ class TestOAuthClientProvider:
assert oauth_provider._client_info == oauth_client_info assert oauth_provider._client_info == oauth_client_info
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_or_register_client_existing( async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test getting existing client info.""" """Test getting existing client info."""
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -460,13 +419,9 @@ class TestOAuthClientProvider:
assert result == oauth_client_info assert result == oauth_client_info
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_or_register_client_register_new( async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test registering new client.""" """Test registering new client."""
with patch.object( with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register:
oauth_provider, "_register_oauth_client", return_value=oauth_client_info
) as mock_register:
result = await oauth_provider._get_or_register_client() result = await oauth_provider._get_or_register_client()
assert result == oauth_client_info assert result == oauth_client_info
@@ -474,9 +429,7 @@ class TestOAuthClientProvider:
mock_register.assert_called_once() mock_register.assert_called_once()
@pytest.mark.anyio @pytest.mark.anyio
async def test_exchange_code_for_token_success( async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token):
self, oauth_provider, oauth_client_info, oauth_token
):
"""Test successful code exchange for token.""" """Test successful code exchange for token."""
oauth_provider._code_verifier = "test_verifier" oauth_provider._code_verifier = "test_verifier"
token_response = oauth_token.model_dump(by_alias=True, mode="json") token_response = oauth_token.model_dump(by_alias=True, mode="json")
@@ -490,23 +443,14 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
with patch.object( with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
oauth_provider, "_validate_token_scopes" await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info)
) as mock_validate:
await oauth_provider._exchange_code_for_token(
"test_auth_code", oauth_client_info
)
assert ( assert oauth_provider._current_tokens.access_token == oauth_token.access_token
oauth_provider._current_tokens.access_token
== oauth_token.access_token
)
mock_validate.assert_called_once() mock_validate.assert_called_once()
@pytest.mark.anyio @pytest.mark.anyio
async def test_exchange_code_for_token_failure( async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test failed code exchange for token.""" """Test failed code exchange for token."""
oauth_provider._code_verifier = "test_verifier" oauth_provider._code_verifier = "test_verifier"
@@ -520,14 +464,10 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"): with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token( await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
"invalid_auth_code", oauth_client_info
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_refresh_access_token_success( async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token):
self, oauth_provider, oauth_client_info, oauth_token
):
"""Test successful token refresh.""" """Test successful token refresh."""
oauth_provider._current_tokens = oauth_token oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -550,16 +490,11 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
with patch.object( with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
oauth_provider, "_validate_token_scopes"
) as mock_validate:
result = await oauth_provider._refresh_access_token() result = await oauth_provider._refresh_access_token()
assert result is True assert result is True
assert ( assert oauth_provider._current_tokens.access_token == new_token.access_token
oauth_provider._current_tokens.access_token
== new_token.access_token
)
mock_validate.assert_called_once() mock_validate.assert_called_once()
@pytest.mark.anyio @pytest.mark.anyio
@@ -575,9 +510,7 @@ class TestOAuthClientProvider:
assert result is False assert result is False
@pytest.mark.anyio @pytest.mark.anyio
async def test_refresh_access_token_failure( async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token):
self, oauth_provider, oauth_client_info, oauth_token
):
"""Test failed token refresh.""" """Test failed token refresh."""
oauth_provider._current_tokens = oauth_token oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -594,9 +527,7 @@ class TestOAuthClientProvider:
assert result is False assert result is False
@pytest.mark.anyio @pytest.mark.anyio
async def test_perform_oauth_flow_success( async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info):
self, oauth_provider, oauth_metadata, oauth_client_info
):
"""Test successful OAuth flow.""" """Test successful OAuth flow."""
oauth_provider._metadata = oauth_metadata oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -640,9 +571,7 @@ class TestOAuthClientProvider:
mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info) mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info)
@pytest.mark.anyio @pytest.mark.anyio
async def test_perform_oauth_flow_state_mismatch( async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info):
self, oauth_provider, oauth_metadata, oauth_client_info
):
"""Test OAuth flow with state parameter mismatch.""" """Test OAuth flow with state parameter mismatch."""
oauth_provider._metadata = oauth_metadata oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -678,9 +607,7 @@ class TestOAuthClientProvider:
oauth_provider._current_tokens = oauth_token oauth_provider._current_tokens = oauth_token
oauth_provider._token_expiry_time = time.time() - 3600 # Expired oauth_provider._token_expiry_time = time.time() - 3600 # Expired
with patch.object( with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh:
oauth_provider, "_refresh_access_token", return_value=True
) as mock_refresh:
await oauth_provider.ensure_token() await oauth_provider.ensure_token()
mock_refresh.assert_called_once() mock_refresh.assert_called_once()
@@ -707,10 +634,7 @@ class TestOAuthClientProvider:
auth_flow = oauth_provider.async_auth_flow(request) auth_flow = oauth_provider.async_auth_flow(request)
updated_request = await auth_flow.__anext__() updated_request = await auth_flow.__anext__()
assert ( assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"
updated_request.headers["Authorization"]
== f"Bearer {oauth_token.access_token}"
)
# Send mock response # Send mock response
try: try:
@@ -761,9 +685,7 @@ class TestOAuthClientProvider:
assert "Authorization" not in updated_request.headers assert "Authorization" not in updated_request.headers
@pytest.mark.anyio @pytest.mark.anyio
async def test_scope_priority_client_metadata_first( async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test that client metadata scope takes priority.""" """Test that client metadata scope takes priority."""
oauth_provider.client_metadata.scope = "read write" oauth_provider.client_metadata.scope = "read write"
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -782,18 +704,13 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow # Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope: if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope auth_params["scope"] = oauth_provider.client_metadata.scope
elif ( elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
auth_params["scope"] = oauth_provider._client_info.scope auth_params["scope"] = oauth_provider._client_info.scope
assert auth_params["scope"] == "read write" assert auth_params["scope"] == "read write"
@pytest.mark.anyio @pytest.mark.anyio
async def test_scope_priority_no_client_metadata_scope( async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info):
self, oauth_provider, oauth_client_info
):
"""Test that no scope parameter is set when client metadata has no scope.""" """Test that no scope parameter is set when client metadata has no scope."""
oauth_provider.client_metadata.scope = None oauth_provider.client_metadata.scope = None
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -837,10 +754,7 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow # Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope: if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope auth_params["scope"] = oauth_provider.client_metadata.scope
elif ( elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
hasattr(oauth_provider._client_info, "scope")
and oauth_provider._client_info.scope
):
auth_params["scope"] = oauth_provider._client_info.scope auth_params["scope"] = oauth_provider._client_info.scope
# No scope should be set # No scope should be set
@@ -866,9 +780,7 @@ class TestOAuthClientProvider:
oauth_provider.redirect_handler = mock_redirect_handler oauth_provider.redirect_handler = mock_redirect_handler
# Patch secrets.compare_digest to verify it's being called # Patch secrets.compare_digest to verify it's being called
with patch( with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare:
"mcp.client.auth.secrets.compare_digest", return_value=False
) as mock_compare:
with pytest.raises(Exception, match="State parameter mismatch"): with pytest.raises(Exception, match="State parameter mismatch"):
await oauth_provider._perform_oauth_flow() await oauth_provider._perform_oauth_flow()
@@ -876,9 +788,7 @@ class TestOAuthClientProvider:
mock_compare.assert_called_once() mock_compare.assert_called_once()
@pytest.mark.anyio @pytest.mark.anyio
async def test_state_parameter_validation_none_state( async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info):
self, oauth_provider, oauth_metadata, oauth_client_info
):
"""Test that None state is handled correctly.""" """Test that None state is handled correctly."""
oauth_provider._metadata = oauth_metadata oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info oauth_provider._client_info = oauth_client_info
@@ -913,9 +823,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response mock_client.post.return_value = mock_response
with pytest.raises(Exception, match="Token exchange failed"): with pytest.raises(Exception, match="Token exchange failed"):
await oauth_provider._exchange_code_for_token( await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)
"invalid_auth_code", oauth_client_info
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -968,9 +876,7 @@ def test_build_metadata(
metadata = build_metadata( metadata = build_metadata(
issuer_url=AnyHttpUrl(issuer_url), issuer_url=AnyHttpUrl(issuer_url),
service_documentation_url=AnyHttpUrl(service_documentation_url), service_documentation_url=AnyHttpUrl(service_documentation_url),
client_registration_options=ClientRegistrationOptions( client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]),
enabled=True, valid_scopes=["read", "write", "admin"]
),
revocation_options=RevocationOptions(enabled=True), revocation_options=RevocationOptions(enabled=True),
) )

View File

@@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path):
test_args = [command] + args + ["--help"] test_args = [command] + args + ["--help"]
result = subprocess.run( result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False)
test_args, capture_output=True, text=True, timeout=5, check=False
)
assert result.returncode == 0 assert result.returncode == 0
assert "usage" in result.stdout.lower() assert "usage" in result.stdout.lower()

View File

@@ -182,9 +182,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test without cursor parameter (omitted) # Test without cursor parameter (omitted)
_ = await client_session.list_resource_templates() _ = await client_session.list_resource_templates()
list_templates_requests = spies.get_client_requests( list_templates_requests = spies.get_client_requests(method="resources/templates/list")
method="resources/templates/list"
)
assert len(list_templates_requests) == 1 assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None assert list_templates_requests[0].params is None
@@ -192,9 +190,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with cursor=None # Test with cursor=None
_ = await client_session.list_resource_templates(cursor=None) _ = await client_session.list_resource_templates(cursor=None)
list_templates_requests = spies.get_client_requests( list_templates_requests = spies.get_client_requests(method="resources/templates/list")
method="resources/templates/list"
)
assert len(list_templates_requests) == 1 assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None assert list_templates_requests[0].params is None
@@ -202,9 +198,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with cursor as string # Test with cursor as string
_ = await client_session.list_resource_templates(cursor="some_cursor") _ = await client_session.list_resource_templates(cursor="some_cursor")
list_templates_requests = spies.get_client_requests( list_templates_requests = spies.get_client_requests(method="resources/templates/list")
method="resources/templates/list"
)
assert len(list_templates_requests) == 1 assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "some_cursor" assert list_templates_requests[0].params["cursor"] == "some_cursor"
@@ -213,9 +207,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):
# Test with empty string cursor # Test with empty string cursor
_ = await client_session.list_resource_templates(cursor="") _ = await client_session.list_resource_templates(cursor="")
list_templates_requests = spies.get_client_requests( list_templates_requests = spies.get_client_requests(method="resources/templates/list")
method="resources/templates/list"
)
assert len(list_templates_requests) == 1 assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "" assert list_templates_requests[0].params["cursor"] == ""

View File

@@ -41,13 +41,9 @@ async def test_list_roots_callback():
return True return True
# Test with list_roots callback # Test with list_roots callback
async with create_session( async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session:
server._mcp_server, list_roots_callback=list_roots_callback
) as client_session:
# Make a request to trigger sampling callback # Make a request to trigger sampling callback
result = await client_session.call_tool( result = await client_session.call_tool("test_list_roots", {"message": "test message"})
"test_list_roots", {"message": "test message"}
)
assert result.isError is False assert result.isError is False
assert isinstance(result.content[0], TextContent) assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true" assert result.content[0].text == "true"
@@ -55,12 +51,7 @@ async def test_list_roots_callback():
# Test without list_roots callback # Test without list_roots callback
async with create_session(server._mcp_server) as client_session: async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback # Make a request to trigger sampling callback
result = await client_session.call_tool( result = await client_session.call_tool("test_list_roots", {"message": "test message"})
"test_list_roots", {"message": "test message"}
)
assert result.isError is True assert result.isError is True
assert isinstance(result.content[0], TextContent) assert isinstance(result.content[0], TextContent)
assert ( assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"
result.content[0].text
== "Error executing tool test_list_roots: List roots not supported"
)

View File

@@ -49,9 +49,7 @@ async def test_logging_callback():
# Create a message handler to catch exceptions # Create a message handler to catch exceptions
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message

View File

@@ -21,9 +21,7 @@ async def test_sampling_callback():
callback_return = CreateMessageResult( callback_return = CreateMessageResult(
role="assistant", role="assistant",
content=TextContent( content=TextContent(type="text", text="This is a response from the sampling callback"),
type="text", text="This is a response from the sampling callback"
),
model="test-model", model="test-model",
stopReason="endTurn", stopReason="endTurn",
) )
@@ -37,24 +35,16 @@ async def test_sampling_callback():
@server.tool("test_sampling") @server.tool("test_sampling")
async def test_sampling_tool(message: str): async def test_sampling_tool(message: str):
value = await server.get_context().session.create_message( value = await server.get_context().session.create_message(
messages=[ messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
SamplingMessage(
role="user", content=TextContent(type="text", text=message)
)
],
max_tokens=100, max_tokens=100,
) )
assert value == callback_return assert value == callback_return
return True return True
# Test with sampling callback # Test with sampling callback
async with create_session( async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
server._mcp_server, sampling_callback=sampling_callback
) as client_session:
# Make a request to trigger sampling callback # Make a request to trigger sampling callback
result = await client_session.call_tool( result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is False assert result.isError is False
assert isinstance(result.content[0], TextContent) assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true" assert result.content[0].text == "true"
@@ -62,12 +52,7 @@ async def test_sampling_callback():
# Test without sampling callback # Test without sampling callback
async with create_session(server._mcp_server) as client_session: async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback # Make a request to trigger sampling callback
result = await client_session.call_tool( result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is True assert result.isError is True
assert isinstance(result.content[0], TextContent) assert isinstance(result.content[0], TextContent)
assert ( assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
result.content[0].text
== "Error executing tool test_sampling: Sampling not supported"
)

View File

@@ -28,12 +28,8 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_initialize(): async def test_client_session_initialize():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
initialized_notification = None initialized_notification = None
@@ -70,9 +66,7 @@ async def test_client_session_initialize():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -81,16 +75,12 @@ async def test_client_session_initialize():
jsonrpc_notification = session_notification.message jsonrpc_notification = session_notification.message
assert isinstance(jsonrpc_notification.root, JSONRPCNotification) assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
initialized_notification = ClientNotification.model_validate( initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.model_dump( jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
by_alias=True, mode="json", exclude_none=True
)
) )
# Create a message handler to catch exceptions # Create a message handler to catch exceptions
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
@@ -124,12 +114,8 @@ async def test_client_session_initialize():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_custom_client_info(): async def test_client_session_custom_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
custom_client_info = Implementation(name="test-client", version="1.2.3") custom_client_info = Implementation(name="test-client", version="1.2.3")
received_client_info = None received_client_info = None
@@ -161,9 +147,7 @@ async def test_client_session_custom_client_info():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -192,12 +176,8 @@ async def test_client_session_custom_client_info():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_default_client_info(): async def test_client_session_default_client_info():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
received_client_info = None received_client_info = None
@@ -228,9 +208,7 @@ async def test_client_session_default_client_info():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -259,12 +237,8 @@ async def test_client_session_default_client_info():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_version_negotiation_success(): async def test_client_session_version_negotiation_success():
"""Test successful version negotiation with supported version""" """Test successful version negotiation with supported version"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
async def mock_server(): async def mock_server():
session_message = await client_to_server_receive.receive() session_message = await client_to_server_receive.receive()
@@ -294,9 +268,7 @@ async def test_client_session_version_negotiation_success():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -327,12 +299,8 @@ async def test_client_session_version_negotiation_success():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_session_version_negotiation_failure(): async def test_client_session_version_negotiation_failure():
"""Test version negotiation failure with unsupported version""" """Test version negotiation failure with unsupported version"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
async def mock_server(): async def mock_server():
session_message = await client_to_server_receive.receive() session_message = await client_to_server_receive.receive()
@@ -359,9 +327,7 @@ async def test_client_session_version_negotiation_failure():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -388,12 +354,8 @@ async def test_client_session_version_negotiation_failure():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_capabilities_default(): async def test_client_capabilities_default():
"""Test that client capabilities are properly set with default callbacks""" """Test that client capabilities are properly set with default callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
received_capabilities = None received_capabilities = None
@@ -424,9 +386,7 @@ async def test_client_capabilities_default():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -457,12 +417,8 @@ async def test_client_capabilities_default():
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_capabilities_with_custom_callbacks(): async def test_client_capabilities_with_custom_callbacks():
"""Test that client capabilities are properly set with custom callbacks""" """Test that client capabilities are properly set with custom callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
received_capabilities = None received_capabilities = None
@@ -508,9 +464,7 @@ async def test_client_capabilities_with_custom_callbacks():
JSONRPCResponse( JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=jsonrpc_request.root.id, id=jsonrpc_request.root.id,
result=result.model_dump( result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
by_alias=True, mode="json", exclude_none=True
),
) )
) )
) )
@@ -536,14 +490,8 @@ async def test_client_capabilities_with_custom_callbacks():
# Assert that capabilities are properly set with custom callbacks # Assert that capabilities are properly set with custom callbacks
assert received_capabilities is not None assert received_capabilities is not None
assert ( assert received_capabilities.sampling is not None # Custom sampling callback provided
received_capabilities.sampling is not None
) # Custom sampling callback provided
assert isinstance(received_capabilities.sampling, types.SamplingCapability) assert isinstance(received_capabilities.sampling, types.SamplingCapability)
assert ( assert received_capabilities.roots is not None # Custom list_roots callback provided
received_capabilities.roots is not None
) # Custom list_roots callback provided
assert isinstance(received_capabilities.roots, types.RootsCapability) assert isinstance(received_capabilities.roots, types.RootsCapability)
assert ( assert received_capabilities.roots.listChanged is True # Should be True for custom callback
received_capabilities.roots.listChanged is True
) # Should be True for custom callback

View File

@@ -58,14 +58,10 @@ class TestClientSessionGroup:
return f"{(server_info.name)}-{name}" return f"{(server_info.name)}-{name}"
mcp_session_group = ClientSessionGroup(component_name_hook=hook) mcp_session_group = ClientSessionGroup(component_name_hook=hook)
mcp_session_group._tools = { mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
}
mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
text_content = types.TextContent(type="text", text="OK") text_content = types.TextContent(type="text", text="OK")
mock_session.call_tool.return_value = types.CallToolResult( mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])
content=[text_content]
)
# --- Test Execution --- # --- Test Execution ---
result = await mcp_session_group.call_tool( result = await mcp_session_group.call_tool(
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
mock_prompt1 = mock.Mock(spec=types.Prompt) mock_prompt1 = mock.Mock(spec=types.Prompt)
mock_prompt1.name = "prompt_c" mock_prompt1.name = "prompt_c"
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
mock_session.list_resources.return_value = mock.AsyncMock( mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
resources=[mock_resource1]
)
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])
# --- Test Execution --- # --- Test Execution ---
group = ClientSessionGroup(exit_stack=mock_exit_stack) group = ClientSessionGroup(exit_stack=mock_exit_stack)
with mock.patch.object( with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
await group.connect_to_server(StdioServerParameters(command="test")) await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions --- # --- Assertions ---
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
return f"{server_info.name}.{name}" return f"{server_info.name}.{name}"
# --- Test Execution --- # --- Test Execution ---
group = ClientSessionGroup( group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
exit_stack=mock_exit_stack, component_name_hook=name_hook with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
)
with mock.patch.object(
group, "_establish_session", return_value=(mock_server_info, mock_session)
):
await group.connect_to_server(StdioServerParameters(command="test")) await group.connect_to_server(StdioServerParameters(command="test"))
# --- Assertions --- # --- Assertions ---
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
# Need a dummy session associated with the existing tool # Need a dummy session associated with the existing tool
mock_session = mock.MagicMock(spec=mcp.ClientSession) mock_session = mock.MagicMock(spec=mcp.ClientSession)
group._tool_to_session[existing_tool_name] = mock_session group._tool_to_session[existing_tool_name] = mock_session
group._session_exit_stacks[mock_session] = mock.Mock( group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)
spec=contextlib.AsyncExitStack
)
# --- Mock New Connection Attempt --- # --- Mock New Connection Attempt ---
mock_server_info_new = mock.Mock(spec=types.Implementation) mock_server_info_new = mock.Mock(spec=types.Implementation)
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
# Configure the new session to return a tool with the *same name* # Configure the new session to return a tool with the *same name*
duplicate_tool = mock.Mock(spec=types.Tool) duplicate_tool = mock.Mock(spec=types.Tool)
duplicate_tool.name = existing_tool_name duplicate_tool.name = existing_tool_name
mock_session_new.list_tools.return_value = mock.AsyncMock( mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
tools=[duplicate_tool]
)
# Keep other lists empty for simplicity # Keep other lists empty for simplicity
mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
@@ -266,9 +250,7 @@ class TestClientSessionGroup:
# Verify the duplicate tool was *not* added again (state should be unchanged) # Verify the duplicate tool was *not* added again (state should be unchanged)
assert len(group._tools) == 1 # Should still only have the original assert len(group._tools) == 1 # Should still only have the original
assert ( assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock
group._tools[existing_tool_name] is not duplicate_tool
) # Ensure it's the original mock
# No patching needed here # No patching needed here
async def test_disconnect_non_existent_server(self): async def test_disconnect_non_existent_server(self):
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
"mcp.client.session_group.sse_client", "mcp.client.session_group.sse_client",
), # url, headers, timeout, sse_read_timeout ), # url, headers, timeout, sse_read_timeout
( (
StreamableHttpParameters( StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
url="http://test.com/stream", terminate_on_close=False
),
"streamablehttp", "streamablehttp",
"mcp.client.session_group.streamablehttp_client", "mcp.client.session_group.streamablehttp_client",
), # url, headers, timeout, sse_read_timeout, terminate_on_close ), # url, headers, timeout, sse_read_timeout, terminate_on_close
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
client_type_name, # Just for clarity or conditional logic if needed client_type_name, # Just for clarity or conditional logic if needed
patch_target_for_client_func, patch_target_for_client_func,
): ):
with mock.patch( with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
"mcp.client.session_group.mcp.ClientSession"
) as mock_ClientSession_class:
with mock.patch(patch_target_for_client_func) as mock_specific_client_func: with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
mock_client_cm_instance = mock.AsyncMock( mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
name=f"{client_type_name}ClientCM"
)
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
@@ -344,9 +320,7 @@ class TestClientSessionGroup:
# Mock session.initialize() # Mock session.initialize()
mock_initialize_result = mock.AsyncMock(name="InitializeResult") mock_initialize_result = mock.AsyncMock(name="InitializeResult")
mock_initialize_result.serverInfo = types.Implementation( mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
name="foo", version="1"
)
mock_entered_session.initialize.return_value = mock_initialize_result mock_entered_session.initialize.return_value = mock_initialize_result
# --- Test Execution --- # --- Test Execution ---
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
# --- Assertions --- # --- Assertions ---
# 1. Assert the correct specific client function was called # 1. Assert the correct specific client function was called
if client_type_name == "stdio": if client_type_name == "stdio":
mock_specific_client_func.assert_called_once_with( mock_specific_client_func.assert_called_once_with(server_params_instance)
server_params_instance
)
elif client_type_name == "sse": elif client_type_name == "sse":
mock_specific_client_func.assert_called_once_with( mock_specific_client_func.assert_called_once_with(
url=server_params_instance.url, url=server_params_instance.url,
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
mock_client_cm_instance.__aenter__.assert_awaited_once() mock_client_cm_instance.__aenter__.assert_awaited_once()
# 2. Assert ClientSession was called correctly # 2. Assert ClientSession was called correctly
mock_ClientSession_class.assert_called_once_with( mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
mock_read_stream, mock_write_stream
)
mock_raw_session_cm.__aenter__.assert_awaited_once() mock_raw_session_cm.__aenter__.assert_awaited_once()
mock_entered_session.initialize.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once()

View File

@@ -50,20 +50,14 @@ async def test_stdio_client():
break break
assert len(read_messages) == 2 assert len(read_messages) == 2
assert read_messages[0] == JSONRPCMessage( assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
)
assert read_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_stdio_client_bad_path(): async def test_stdio_client_bad_path():
"""Check that the connection doesn't hang if process errors.""" """Check that the connection doesn't hang if process errors."""
server_params = StdioServerParameters( server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
command="python", args=["-c", "non-existent-file.py"]
)
async with stdio_client(server_params) as (read_stream, write_stream): async with stdio_client(server_params) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: async with ClientSession(read_stream, write_stream) as session:
# The session should raise an error when the connection closes # The session should raise an error when the connection closes

View File

@@ -17,9 +17,7 @@ async def test_list_tools_returns_all_tools():
f"""Tool number {i}""" f"""Tool number {i}"""
return i return i
globals()[f"dummy_tool_{i}"] = ( globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection
dummy_tool_func # Keep reference to avoid garbage collection
)
# Get all tools # Get all tools
tools = await mcp.list_tools() tools = await mcp.list_tools()
@@ -30,6 +28,4 @@ async def test_list_tools_returns_all_tools():
# Verify each tool is unique and has the correct name # Verify each tool is unique and has the correct name
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]
expected_names = [f"tool_{i}" for i in range(num_tools)] expected_names = [f"tool_{i}" for i in range(num_tools)]
assert sorted(tool_names) == sorted( assert sorted(tool_names) == sorted(expected_names), "Tool names don't match expected names"
expected_names
), "Tool names don't match expected names"

View File

@@ -24,9 +24,7 @@ async def test_resource_templates():
# Note: list_resource_templates() returns a decorator that wraps the handler # Note: list_resource_templates() returns a decorator that wraps the handler
# The handler returns a ServerResult with a ListResourceTemplatesResult inside # The handler returns a ServerResult with a ListResourceTemplatesResult inside
result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest](
types.ListResourceTemplatesRequest( types.ListResourceTemplatesRequest(method="resources/templates/list", params=None)
method="resources/templates/list", params=None
)
) )
assert isinstance(result.root, types.ListResourceTemplatesResult) assert isinstance(result.root, types.ListResourceTemplatesResult)
templates = result.root.resourceTemplates templates = result.root.resourceTemplates

View File

@@ -61,9 +61,7 @@ async def test_resource_template_edge_cases():
await mcp.read_resource("resource://users/123/posts") # Missing post_id await mcp.read_resource("resource://users/123/posts") # Missing post_id
with pytest.raises(ValueError, match="Unknown resource"): with pytest.raises(ValueError, match="Unknown resource"):
await mcp.read_resource( await mcp.read_resource("resource://users/123/posts/456/extra") # Extra path component
"resource://users/123/posts/456/extra"
) # Extra path component
@pytest.mark.anyio @pytest.mark.anyio
@@ -110,11 +108,7 @@ async def test_resource_template_client_interaction():
# Verify invalid resource URIs raise appropriate errors # Verify invalid resource URIs raise appropriate errors
with pytest.raises(Exception): # Specific exception type may vary with pytest.raises(Exception): # Specific exception type may vary
await session.read_resource( await session.read_resource(AnyUrl("resource://users/123/posts")) # Missing post_id
AnyUrl("resource://users/123/posts")
) # Missing post_id
with pytest.raises(Exception): # Specific exception type may vary with pytest.raises(Exception): # Specific exception type may vary
await session.read_resource( await session.read_resource(AnyUrl("resource://users/123/invalid")) # Invalid template
AnyUrl("resource://users/123/invalid")
) # Invalid template

View File

@@ -45,31 +45,19 @@ async def test_fastmcp_resource_mime_type():
bytes_resource = mapping["test://image_bytes"] bytes_resource = mapping["test://image_bytes"]
# Verify mime types # Verify mime types
assert ( assert string_resource.mimeType == "image/png", "String resource mime type not respected"
string_resource.mimeType == "image/png" assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected"
), "String resource mime type not respected"
assert (
bytes_resource.mimeType == "image/png"
), "Bytes resource mime type not respected"
# Also verify the content can be read correctly # Also verify the content can be read correctly
string_result = await client.read_resource(AnyUrl("test://image")) string_result = await client.read_resource(AnyUrl("test://image"))
assert len(string_result.contents) == 1 assert len(string_result.contents) == 1
assert ( assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch"
getattr(string_result.contents[0], "text") == base64_string assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved"
), "Base64 string mismatch"
assert (
string_result.contents[0].mimeType == "image/png"
), "String content mime type not preserved"
bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) bytes_result = await client.read_resource(AnyUrl("test://image_bytes"))
assert len(bytes_result.contents) == 1 assert len(bytes_result.contents) == 1
assert ( assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch"
base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved"
), "Bytes mismatch"
assert (
bytes_result.contents[0].mimeType == "image/png"
), "Bytes content mime type not preserved"
async def test_lowlevel_resource_mime_type(): async def test_lowlevel_resource_mime_type():
@@ -82,9 +70,7 @@ async def test_lowlevel_resource_mime_type():
# Create test resources with specific mime types # Create test resources with specific mime types
test_resources = [ test_resources = [
types.Resource( types.Resource(uri=AnyUrl("test://image"), name="test image", mimeType="image/png"),
uri=AnyUrl("test://image"), name="test image", mimeType="image/png"
),
types.Resource( types.Resource(
uri=AnyUrl("test://image_bytes"), uri=AnyUrl("test://image_bytes"),
name="test image bytes", name="test image bytes",
@@ -101,9 +87,7 @@ async def test_lowlevel_resource_mime_type():
if str(uri) == "test://image": if str(uri) == "test://image":
return [ReadResourceContents(content=base64_string, mime_type="image/png")] return [ReadResourceContents(content=base64_string, mime_type="image/png")]
elif str(uri) == "test://image_bytes": elif str(uri) == "test://image_bytes":
return [ return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")]
ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")
]
raise Exception(f"Resource not found: {uri}") raise Exception(f"Resource not found: {uri}")
# Test that resources are listed with correct mime type # Test that resources are listed with correct mime type
@@ -119,28 +103,16 @@ async def test_lowlevel_resource_mime_type():
bytes_resource = mapping["test://image_bytes"] bytes_resource = mapping["test://image_bytes"]
# Verify mime types # Verify mime types
assert ( assert string_resource.mimeType == "image/png", "String resource mime type not respected"
string_resource.mimeType == "image/png" assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected"
), "String resource mime type not respected"
assert (
bytes_resource.mimeType == "image/png"
), "Bytes resource mime type not respected"
# Also verify the content can be read correctly # Also verify the content can be read correctly
string_result = await client.read_resource(AnyUrl("test://image")) string_result = await client.read_resource(AnyUrl("test://image"))
assert len(string_result.contents) == 1 assert len(string_result.contents) == 1
assert ( assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch"
getattr(string_result.contents[0], "text") == base64_string assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved"
), "Base64 string mismatch"
assert (
string_result.contents[0].mimeType == "image/png"
), "String content mime type not preserved"
bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) bytes_result = await client.read_resource(AnyUrl("test://image_bytes"))
assert len(bytes_result.contents) == 1 assert len(bytes_result.contents) == 1
assert ( assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch"
base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved"
), "Bytes mismatch"
assert (
bytes_result.contents[0].mimeType == "image/png"
), "Bytes content mime type not preserved"

View File

@@ -35,15 +35,7 @@ async def test_progress_token_zero_first_call():
await ctx.report_progress(10, 10) # Complete await ctx.report_progress(10, 10) # Complete
# Verify progress notifications # Verify progress notifications
assert ( assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent"
mock_session.send_progress_notification.call_count == 3 mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None)
), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None)
mock_session.send_progress_notification.assert_any_call( mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None)
progress_token=0, progress=0.0, total=10.0, message=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=5.0, total=10.0, message=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=10.0, total=10.0, message=None
)

View File

@@ -66,9 +66,7 @@ async def test_request_id_match() -> None:
) )
await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
response = ( response = await server_reader.receive() # Get init response but don't need to check it
await server_reader.receive()
) # Get init response but don't need to check it
# Send initialized notification # Send initialized notification
initialized_notification = JSONRPCNotification( initialized_notification = JSONRPCNotification(
@@ -76,14 +74,10 @@ async def test_request_id_match() -> None:
params=NotificationParams().model_dump(by_alias=True, exclude_none=True), params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
jsonrpc="2.0", jsonrpc="2.0",
) )
await client_writer.send( await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification)))
SessionMessage(JSONRPCMessage(root=initialized_notification))
)
# Send ping request with custom ID # Send ping request with custom ID
ping_request = JSONRPCRequest( ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0")
id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
)
await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))
@@ -91,9 +85,7 @@ async def test_request_id_match() -> None:
response = await server_reader.receive() response = await server_reader.receive()
# Verify response ID matches request ID # Verify response ID matches request ID
assert ( assert response.message.root.id == custom_request_id, "Response ID should match request ID"
response.message.root.id == custom_request_id
), "Response ID should match request ID"
# Cancel server task # Cancel server task
tg.cancel_scope.cancel() tg.cancel_scope.cancel()

View File

@@ -47,11 +47,7 @@ async def test_server_base64_encoding_issue():
# Register a resource handler that returns our test data # Register a resource handler that returns our test data
@server.read_resource() @server.read_resource()
async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]:
return [ return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")]
ReadResourceContents(
content=binary_data, mime_type="application/octet-stream"
)
]
# Get the handler directly from the server # Get the handler directly from the server
handler = server.request_handlers[ReadResourceRequest] handler = server.request_handlers[ReadResourceRequest]

View File

@@ -11,12 +11,7 @@ from anyio.abc import TaskStatus
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError from mcp.shared.exceptions import McpError
from mcp.types import ( from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
AudioContent,
EmbeddedResource,
ImageContent,
TextContent,
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -36,9 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
slow_request_complete = anyio.Event() slow_request_complete = anyio.Event()
@server.call_tool() @server.call_tool()
async def slow_tool( async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
name: str, arg
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
nonlocal request_count nonlocal request_count
request_count += 1 request_count += 1
@@ -75,9 +68,7 @@ async def test_notification_validation_error(tmp_path: Path):
# - Long enough for fast operations (>10ms) # - Long enough for fast operations (>10ms)
# - Short enough for slow operations (<200ms) # - Short enough for slow operations (<200ms)
# - Not too short to avoid flakiness # - Not too short to avoid flakiness
async with ClientSession( async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session:
read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
) as session:
await session.initialize() await session.initialize()
# First call should work (fast operation) # First call should work (fast operation)

View File

@@ -1,4 +1,4 @@
# Claude Debug # Claude Debug
"""Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" """Test for HackerOne vulnerability report #3156202 - malformed input DOS."""
import anyio import anyio
@@ -23,12 +23,8 @@ async def test_malformed_initialize_request_does_not_crash_server():
instead of crashing the server (HackerOne #3156202). instead of crashing the server (HackerOne #3156202).
""" """
# Create in-memory streams for testing # Create in-memory streams for testing
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[ read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
SessionMessage | Exception write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)
](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
SessionMessage
](10)
try: try:
# Create a malformed initialize request (missing required params field) # Create a malformed initialize request (missing required params field)
@@ -38,7 +34,7 @@ async def test_malformed_initialize_request_does_not_crash_server():
method="initialize", method="initialize",
# params=None # Missing required params field # params=None # Missing required params field
) )
# Wrap in session message # Wrap in session message
request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) request_message = SessionMessage(message=JSONRPCMessage(malformed_request))
@@ -54,22 +50,22 @@ async def test_malformed_initialize_request_does_not_crash_server():
): ):
# Send the malformed request # Send the malformed request
await read_send_stream.send(request_message) await read_send_stream.send(request_message)
# Give the session time to process the request # Give the session time to process the request
await anyio.sleep(0.1) await anyio.sleep(0.1)
# Check that we received an error response instead of a crash # Check that we received an error response instead of a crash
try: try:
response_message = write_receive_stream.receive_nowait() response_message = write_receive_stream.receive_nowait()
response = response_message.message.root response = response_message.message.root
# Verify it's a proper JSON-RPC error response # Verify it's a proper JSON-RPC error response
assert isinstance(response, JSONRPCError) assert isinstance(response, JSONRPCError)
assert response.jsonrpc == "2.0" assert response.jsonrpc == "2.0"
assert response.id == "f20fe86132ed4cd197f89a7134de5685" assert response.id == "f20fe86132ed4cd197f89a7134de5685"
assert response.error.code == INVALID_PARAMS assert response.error.code == INVALID_PARAMS
assert "Invalid request parameters" in response.error.message assert "Invalid request parameters" in response.error.message
# Verify the session is still alive and can handle more requests # Verify the session is still alive and can handle more requests
# Send another malformed request to confirm server stability # Send another malformed request to confirm server stability
another_malformed_request = JSONRPCRequest( another_malformed_request = JSONRPCRequest(
@@ -78,21 +74,19 @@ async def test_malformed_initialize_request_does_not_crash_server():
method="tools/call", method="tools/call",
# params=None # Missing required params # params=None # Missing required params
) )
another_request_message = SessionMessage( another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request))
message=JSONRPCMessage(another_malformed_request)
)
await read_send_stream.send(another_request_message) await read_send_stream.send(another_request_message)
await anyio.sleep(0.1) await anyio.sleep(0.1)
# Should get another error response, not a crash # Should get another error response, not a crash
second_response_message = write_receive_stream.receive_nowait() second_response_message = write_receive_stream.receive_nowait()
second_response = second_response_message.message.root second_response = second_response_message.message.root
assert isinstance(second_response, JSONRPCError) assert isinstance(second_response, JSONRPCError)
assert second_response.id == "test_id_2" assert second_response.id == "test_id_2"
assert second_response.error.code == INVALID_PARAMS assert second_response.error.code == INVALID_PARAMS
except anyio.WouldBlock: except anyio.WouldBlock:
pytest.fail("No response received - server likely crashed") pytest.fail("No response received - server likely crashed")
finally: finally:
@@ -109,12 +103,8 @@ async def test_multiple_concurrent_malformed_requests():
Test that multiple concurrent malformed requests don't crash the server. Test that multiple concurrent malformed requests don't crash the server.
""" """
# Create in-memory streams for testing # Create in-memory streams for testing
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[ read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100)
SessionMessage | Exception write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](100)
](100)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
SessionMessage
](100)
try: try:
# Start a server session # Start a server session
@@ -136,18 +126,16 @@ async def test_multiple_concurrent_malformed_requests():
method="initialize", method="initialize",
# params=None # Missing required params # params=None # Missing required params
) )
request_message = SessionMessage( request_message = SessionMessage(message=JSONRPCMessage(malformed_request))
message=JSONRPCMessage(malformed_request)
)
malformed_requests.append(request_message) malformed_requests.append(request_message)
# Send all requests # Send all requests
for request in malformed_requests: for request in malformed_requests:
await read_send_stream.send(request) await read_send_stream.send(request)
# Give time to process # Give time to process
await anyio.sleep(0.2) await anyio.sleep(0.2)
# Verify we get error responses for all requests # Verify we get error responses for all requests
error_responses = [] error_responses = []
try: try:
@@ -156,10 +144,10 @@ async def test_multiple_concurrent_malformed_requests():
error_responses.append(response_message.message.root) error_responses.append(response_message.message.root)
except anyio.WouldBlock: except anyio.WouldBlock:
pass # No more messages pass # No more messages
# Should have received 10 error responses # Should have received 10 error responses
assert len(error_responses) == 10 assert len(error_responses) == 10
for i, response in enumerate(error_responses): for i, response in enumerate(error_responses):
assert isinstance(response, JSONRPCError) assert isinstance(response, JSONRPCError)
assert response.id == f"malformed_{i}" assert response.id == f"malformed_{i}"
@@ -169,4 +157,4 @@ async def test_multiple_concurrent_malformed_requests():
await read_send_stream.aclose() await read_send_stream.aclose()
await write_send_stream.aclose() await write_send_stream.aclose()
await read_receive_stream.aclose() await read_receive_stream.aclose()
await write_receive_stream.aclose() await write_receive_stream.aclose()

View File

@@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken:
class TestBearerAuthBackend: class TestBearerAuthBackend:
"""Tests for the BearerAuthBackend class.""" """Tests for the BearerAuthBackend class."""
async def test_no_auth_header( async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with no Authorization header.""" """Test authentication with no Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider) backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request({"type": "http", "headers": []}) request = Request({"type": "http", "headers": []})
result = await backend.authenticate(request) result = await backend.authenticate(request)
assert result is None assert result is None
async def test_non_bearer_auth_header( async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with non-Bearer Authorization header.""" """Test authentication with non-Bearer Authorization header."""
backend = BearerAuthBackend(provider=mock_oauth_provider) backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request( request = Request(
@@ -139,9 +135,7 @@ class TestBearerAuthBackend:
result = await backend.authenticate(request) result = await backend.authenticate(request)
assert result is None assert result is None
async def test_invalid_token( async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
):
"""Test authentication with invalid token.""" """Test authentication with invalid token."""
backend = BearerAuthBackend(provider=mock_oauth_provider) backend = BearerAuthBackend(provider=mock_oauth_provider)
request = Request( request = Request(
@@ -160,9 +154,7 @@ class TestBearerAuthBackend:
): ):
"""Test authentication with expired token.""" """Test authentication with expired token."""
backend = BearerAuthBackend(provider=mock_oauth_provider) backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider( add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
mock_oauth_provider, "expired_token", expired_access_token
)
request = Request( request = Request(
{ {
"type": "http", "type": "http",
@@ -203,9 +195,7 @@ class TestBearerAuthBackend:
): ):
"""Test authentication with token that has no expiry.""" """Test authentication with token that has no expiry."""
backend = BearerAuthBackend(provider=mock_oauth_provider) backend = BearerAuthBackend(provider=mock_oauth_provider)
add_token_to_provider( add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
mock_oauth_provider, "no_expiry_token", no_expiry_access_token
)
request = Request( request = Request(
{ {
"type": "http", "type": "http",

View File

@@ -128,16 +128,12 @@ class TestRegistrationErrorHandling:
class TestAuthorizeErrorHandling: class TestAuthorizeErrorHandling:
@pytest.mark.anyio @pytest.mark.anyio
async def test_authorize_error_handling( async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge):
self, client, oauth_provider, registered_client, pkce_challenge
):
# Mock the authorize method to raise an authorize error # Mock the authorize method to raise an authorize error
with unittest.mock.patch.object( with unittest.mock.patch.object(
oauth_provider, oauth_provider,
"authorize", "authorize",
side_effect=AuthorizeError( side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"),
error="access_denied", error_description="The user denied the request"
),
): ):
# Register the client # Register the client
client_id = registered_client["client_id"] client_id = registered_client["client_id"]
@@ -169,9 +165,7 @@ class TestAuthorizeErrorHandling:
class TestTokenErrorHandling: class TestTokenErrorHandling:
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_error_handling_auth_code( async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge):
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get an auth code # Register the client and get an auth code
client_id = registered_client["client_id"] client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"] client_secret = registered_client["client_secret"]
@@ -224,9 +218,7 @@ class TestTokenErrorHandling:
assert data["error_description"] == "The authorization code is invalid" assert data["error_description"] == "The authorization code is invalid"
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_error_handling_refresh_token( async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge):
self, client, oauth_provider, registered_client, pkce_challenge
):
# Register the client and get tokens # Register the client and get tokens
client_id = registered_client["client_id"] client_id = registered_client["client_id"]
client_secret = registered_client["client_secret"] client_secret = registered_client["client_secret"]

View File

@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
async def register_client(self, client_info: OAuthClientInformationFull): async def register_client(self, client_info: OAuthClientInformationFull):
self.clients[client_info.client_id] = client_info self.clients[client_info.client_id] = client_info
async def authorize( async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
self, client: OAuthClientInformationFull, params: AuthorizationParams
) -> str:
# toy authorize implementation which just immediately generates an authorization # toy authorize implementation which just immediately generates an authorization
# code and completes the redirect # code and completes the redirect
code = AuthorizationCode( code = AuthorizationCode(
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
) )
self.auth_codes[code.code] = code self.auth_codes[code.code] = code
return construct_redirect_uri( return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)
str(params.redirect_uri), code=code.code, state=params.state
)
async def load_authorization_code( async def load_authorization_code(
self, client: OAuthClientInformationFull, authorization_code: str self, client: OAuthClientInformationFull, authorization_code: str
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
refresh_token=refresh_token, refresh_token=refresh_token,
) )
async def load_refresh_token( async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
self, client: OAuthClientInformationFull, refresh_token: str
) -> RefreshToken | None:
old_access_token = self.refresh_tokens.get(refresh_token) old_access_token = self.refresh_tokens.get(refresh_token)
if old_access_token is None: if old_access_token is None:
return None return None
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):
@pytest.fixture @pytest.fixture
async def test_client(auth_app): async def test_client(auth_app):
async with httpx.AsyncClient( async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
) as client:
yield client yield client
@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
def pkce_challenge(): def pkce_challenge():
"""Create a PKCE challenge with code_verifier and code_challenge.""" """Create a PKCE challenge with code_verifier and code_challenge."""
code_verifier = "some_random_verifier_string" code_verifier = "some_random_verifier_string"
code_challenge = ( code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode()
.rstrip("=")
)
return {"code_verifier": code_verifier, "code_challenge": code_challenge} return {"code_verifier": code_verifier, "code_challenge": code_challenge}
@@ -356,17 +344,13 @@ class TestAuthEndpoints:
metadata = response.json() metadata = response.json()
assert metadata["issuer"] == "https://auth.example.com/" assert metadata["issuer"] == "https://auth.example.com/"
assert ( assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
)
assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["token_endpoint"] == "https://auth.example.com/token"
assert metadata["registration_endpoint"] == "https://auth.example.com/register" assert metadata["registration_endpoint"] == "https://auth.example.com/register"
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
assert metadata["response_types_supported"] == ["code"] assert metadata["response_types_supported"] == ["code"]
assert metadata["code_challenge_methods_supported"] == ["S256"] assert metadata["code_challenge_methods_supported"] == ["S256"]
assert metadata["token_endpoint_auth_methods_supported"] == [ assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
"client_secret_post"
]
assert metadata["grant_types_supported"] == [ assert metadata["grant_types_supported"] == [
"authorization_code", "authorization_code",
"refresh_token", "refresh_token",
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
) )
error_response = response.json() error_response = response.json()
assert error_response["error"] == "invalid_request" assert error_response["error"] == "invalid_request"
assert ( assert "error_description" in error_response # Contains validation error messages
"error_description" in error_response
) # Contains validation error messages
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_invalid_auth_code( async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
self, test_client, registered_client, pkce_challenge
):
"""Test token endpoint error - authorization code does not exist.""" """Test token endpoint error - authorization code does not exist."""
# Try to use a non-existent authorization code # Try to use a non-existent authorization code
response = await test_client.post( response = await test_client.post(
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
assert response.status_code == 400 assert response.status_code == 400
error_response = response.json() error_response = response.json()
assert error_response["error"] == "invalid_grant" assert error_response["error"] == "invalid_grant"
assert ( assert "authorization code does not exist" in error_response["error_description"]
"authorization code does not exist" in error_response["error_description"]
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_expired_auth_code( async def test_token_expired_auth_code(
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
assert response.status_code == 400 assert response.status_code == 400
error_response = response.json() error_response = response.json()
assert error_response["error"] == "invalid_grant" assert error_response["error"] == "invalid_grant"
assert ( assert "authorization code has expired" in error_response["error_description"]
"authorization code has expired" in error_response["error_description"]
)
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
], ],
indirect=True, indirect=True,
) )
async def test_token_redirect_uri_mismatch( async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
self, test_client, registered_client, auth_code, pkce_challenge
):
"""Test token endpoint error - redirect URI mismatch.""" """Test token endpoint error - redirect URI mismatch."""
# Try to use the code with a different redirect URI # Try to use the code with a different redirect URI
response = await test_client.post( response = await test_client.post(
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
assert "redirect_uri did not match" in error_response["error_description"] assert "redirect_uri did not match" in error_response["error_description"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_code_verifier_mismatch( async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
self, test_client, registered_client, auth_code
):
"""Test token endpoint error - PKCE code verifier mismatch.""" """Test token endpoint error - PKCE code verifier mismatch."""
# Try to use the code with an incorrect code verifier # Try to use the code with an incorrect code verifier
response = await test_client.post( response = await test_client.post(
@@ -569,9 +541,7 @@ class TestAuthEndpoints:
# Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
# Mock the time.time() function to return a value 4 hours in the future # Mock the time.time() function to return a value 4 hours in the future
with unittest.mock.patch( with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
"time.time", return_value=current_time + 14400
): # 4 hours = 14400 seconds
# Try to use the refresh token which should now be considered expired # Try to use the refresh token which should now be considered expired
response = await test_client.post( response = await test_client.post(
"/token", "/token",
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
assert "refresh token has expired" in error_response["error_description"] assert "refresh token has expired" in error_response["error_description"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_invalid_scope( async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
self, test_client, registered_client, auth_code, pkce_challenge
):
"""Test token endpoint error - invalid scope in refresh token request.""" """Test token endpoint error - invalid scope in refresh token request."""
# Exchange authorization code for tokens # Exchange authorization code for tokens
token_response = await test_client.post( token_response = await test_client.post(
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
assert "cannot request scope" in error_response["error_description"] assert "cannot request scope" in error_response["error_description"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration( async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
):
"""Test client registration.""" """Test client registration."""
client_metadata = { client_metadata = {
"redirect_uris": ["https://client.example.com/callback"], "redirect_uris": ["https://client.example.com/callback"],
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
# ) is not None # ) is not None
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration_missing_required_fields( async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
self, test_client: httpx.AsyncClient
):
"""Test client registration with missing required fields.""" """Test client registration with missing required fields."""
# Missing redirect_uris which is a required field # Missing redirect_uris which is a required field
client_metadata = { client_metadata = {
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
assert error_data["error_description"] == "redirect_uris: Field required" assert error_data["error_description"] == "redirect_uris: Field required"
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration_invalid_uri( async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
self, test_client: httpx.AsyncClient
):
"""Test client registration with invalid URIs.""" """Test client registration with invalid URIs."""
# Invalid redirect_uri format # Invalid redirect_uri format
client_metadata = { client_metadata = {
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
assert "error" in error_data assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata" assert error_data["error"] == "invalid_client_metadata"
assert error_data["error_description"] == ( assert error_data["error_description"] == (
"redirect_uris.0: Input should be a valid URL, " "redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
"relative URL without a base"
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration_empty_redirect_uris( async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
self, test_client: httpx.AsyncClient
):
"""Test client registration with empty redirect_uris array.""" """Test client registration with empty redirect_uris array."""
client_metadata = { client_metadata = {
"redirect_uris": [], # Empty array "redirect_uris": [], # Empty array
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
assert "error" in error_data assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata" assert error_data["error"] == "invalid_client_metadata"
assert ( assert (
error_data["error_description"] error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
== "redirect_uris: List should have at least 1 item after validation, not 0"
) )
@pytest.mark.anyio @pytest.mark.anyio
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
assert response.status_code == 200 assert response.status_code == 200
# Verify that the token was revoked # Verify that the token was revoked
assert ( assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None
await mock_oauth_provider.load_access_token(
new_token_response["access_token"]
)
is None
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_revoke_invalid_token(self, test_client, registered_client): async def test_revoke_invalid_token(self, test_client, registered_client):
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
assert "token_type_hint" in error_response["error_description"] assert "token_type_hint" in error_response["error_description"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration_disallowed_scopes( async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
self, test_client: httpx.AsyncClient
):
"""Test client registration with scopes that are not allowed.""" """Test client registration with scopes that are not allowed."""
client_metadata = { client_metadata = {
"redirect_uris": ["https://client.example.com/callback"], "redirect_uris": ["https://client.example.com/callback"],
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
assert client_info["scope"] == "read write" assert client_info["scope"] == "read write"
# Retrieve the client from the store to verify default scopes # Retrieve the client from the store to verify default scopes
registered_client = await mock_oauth_provider.get_client( registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
client_info["client_id"]
)
assert registered_client is not None assert registered_client is not None
# Check that default scopes were applied # Check that default scopes were applied
assert registered_client.scope == "read write" assert registered_client.scope == "read write"
@pytest.mark.anyio @pytest.mark.anyio
async def test_client_registration_invalid_grant_type( async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
self, test_client: httpx.AsyncClient
):
client_metadata = { client_metadata = {
"redirect_uris": ["https://client.example.com/callback"], "redirect_uris": ["https://client.example.com/callback"],
"client_name": "Test Client", "client_name": "Test Client",
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
error_data = response.json() error_data = response.json()
assert "error" in error_data assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata" assert error_data["error"] == "invalid_client_metadata"
assert ( assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
error_data["error_description"]
== "grant_types must be authorization_code and refresh_token"
)
class TestAuthorizeEndpointErrors: class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint.""" """Test error handling in the OAuth authorization endpoint."""
@pytest.mark.anyio @pytest.mark.anyio
async def test_authorize_missing_client_id( async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
self, test_client: httpx.AsyncClient, pkce_challenge
):
"""Test authorization endpoint with missing client_id. """Test authorization endpoint with missing client_id.
According to the OAuth2.0 spec, if client_id is missing, the server should According to the OAuth2.0 spec, if client_id is missing, the server should
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
assert "client_id" in response.text.lower() assert "client_id" in response.text.lower()
@pytest.mark.anyio @pytest.mark.anyio
async def test_authorize_invalid_client_id( async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
self, test_client: httpx.AsyncClient, pkce_challenge
):
"""Test authorization endpoint with invalid client_id. """Test authorization endpoint with invalid client_id.
According to the OAuth2.0 spec, if client_id is invalid, the server should According to the OAuth2.0 spec, if client_id is invalid, the server should
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state" assert query_params["state"][0] == "test_state"
@pytest.mark.anyio @pytest.mark.anyio
async def test_authorize_missing_pkce_challenge( async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
self, test_client: httpx.AsyncClient, registered_client
):
"""Test authorization endpoint with missing PKCE code_challenge. """Test authorization endpoint with missing PKCE code_challenge.
Missing PKCE parameters should result in invalid_request error. Missing PKCE parameters should result in invalid_request error.
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state" assert query_params["state"][0] == "test_state"
@pytest.mark.anyio @pytest.mark.anyio
async def test_authorize_invalid_scope( async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
self, test_client: httpx.AsyncClient, registered_client, pkce_challenge
):
"""Test authorization endpoint with invalid scope. """Test authorization endpoint with invalid scope.
Invalid scope should redirect with invalid_scope error. Invalid scope should redirect with invalid_scope error.

View File

@@ -18,9 +18,7 @@ class TestRenderPrompt:
return "Hello, world!" return "Hello, world!"
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [ assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_async_fn(self): async def test_async_fn(self):
@@ -28,9 +26,7 @@ class TestRenderPrompt:
return "Hello, world!" return "Hello, world!"
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [ assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_fn_with_args(self): async def test_fn_with_args(self):
@@ -39,11 +35,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render(arguments={"name": "World"}) == [ assert await prompt.render(arguments={"name": "World"}) == [
UserMessage( UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old."))
content=TextContent(
type="text", text="Hello, World! You're 30 years old."
)
)
] ]
@pytest.mark.anyio @pytest.mark.anyio
@@ -61,21 +53,15 @@ class TestRenderPrompt:
return UserMessage(content="Hello, world!") return UserMessage(content="Hello, world!")
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [ assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_fn_returns_assistant_message(self): async def test_fn_returns_assistant_message(self):
async def fn() -> AssistantMessage: async def fn() -> AssistantMessage:
return AssistantMessage( return AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
content=TextContent(type="text", text="Hello, world!")
)
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [ assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))]
AssistantMessage(content=TextContent(type="text", text="Hello, world!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_fn_returns_multiple_messages(self): async def test_fn_returns_multiple_messages(self):
@@ -156,9 +142,7 @@ class TestRenderPrompt:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
assert await prompt.render() == [ assert await prompt.render() == [
UserMessage( UserMessage(content=TextContent(type="text", text="Please analyze this file:")),
content=TextContent(type="text", text="Please analyze this file:")
),
UserMessage( UserMessage(
content=EmbeddedResource( content=EmbeddedResource(
type="resource", type="resource",
@@ -169,9 +153,7 @@ class TestRenderPrompt:
), ),
) )
), ),
AssistantMessage( AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")),
content=TextContent(type="text", text="I'll help analyze that file.")
),
] ]
@pytest.mark.anyio @pytest.mark.anyio

View File

@@ -72,9 +72,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
manager.add_prompt(prompt) manager.add_prompt(prompt)
messages = await manager.render_prompt("fn") messages = await manager.render_prompt("fn")
assert messages == [ assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]
UserMessage(content=TextContent(type="text", text="Hello, world!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_render_prompt_with_args(self): async def test_render_prompt_with_args(self):
@@ -87,9 +85,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn) prompt = Prompt.from_function(fn)
manager.add_prompt(prompt) manager.add_prompt(prompt)
messages = await manager.render_prompt("fn", arguments={"name": "World"}) messages = await manager.render_prompt("fn", arguments={"name": "World"})
assert messages == [ assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))]
UserMessage(content=TextContent(type="text", text="Hello, World!"))
]
@pytest.mark.anyio @pytest.mark.anyio
async def test_render_unknown_prompt(self): async def test_render_unknown_prompt(self):

View File

@@ -100,9 +100,7 @@ class TestFileResource:
with pytest.raises(ValueError, match="Error reading file"): with pytest.raises(ValueError, match="Error reading file"):
await resource.read() await resource.read()
@pytest.mark.skipif( @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows")
os.name == "nt", reason="File permissions behave differently on Windows"
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_permission_error(self, temp_file: Path): async def test_permission_error(self, temp_file: Path):
"""Test reading a file without permissions.""" """Test reading a file without permissions."""

View File

@@ -28,9 +28,7 @@ def complex_arguments_fn(
# list[str] | str is an interesting case because if it comes in as JSON like # list[str] | str is an interesting case because if it comes in as JSON like
# "[\"a\", \"b\"]" then it will be naively parsed as a string. # "[\"a\", \"b\"]" then it will be naively parsed as a string.
list_str_or_str: list[str] | str, list_str_or_str: list[str] | str,
an_int_annotated_with_field: Annotated[ an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")],
int, Field(description="An int with a field")
],
an_int_annotated_with_field_and_others: Annotated[ an_int_annotated_with_field_and_others: Annotated[
int, int,
str, # Should be ignored, really str, # Should be ignored, really
@@ -42,9 +40,7 @@ def complex_arguments_fn(
"123", "123",
456, 456,
], ],
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)],
int, Field(1)
],
unannotated, unannotated,
my_model_a: SomeInputModelA, my_model_a: SomeInputModelA,
my_model_a_forward_ref: "SomeInputModelA", my_model_a_forward_ref: "SomeInputModelA",
@@ -179,9 +175,7 @@ def test_str_vs_list_str():
def test_skip_names(): def test_skip_names():
"""Test that skipped parameters are not included in the model""" """Test that skipped parameters are not included in the model"""
def func_with_many_params( def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool):
keep_this: int, skip_this: str, also_keep: float, also_skip: bool
):
return keep_this, skip_this, also_keep, also_skip return keep_this, skip_this, also_keep, also_skip
# Skip some parameters # Skip some parameters

View File

@@ -130,11 +130,7 @@ def make_everything_fastmcp() -> FastMCP:
# Request sampling from the client # Request sampling from the client
result = await ctx.session.create_message( result = await ctx.session.create_message(
messages=[ messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))],
SamplingMessage(
role="user", content=TextContent(type="text", text=prompt)
)
],
max_tokens=100, max_tokens=100,
temperature=0.7, temperature=0.7,
) )
@@ -278,11 +274,7 @@ def make_fastmcp_stateless_http_app():
def run_server(server_port: int) -> None: def run_server(server_port: int) -> None:
"""Run the server.""" """Run the server."""
_, app = make_fastmcp_app() _, app = make_fastmcp_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting server on port {server_port}") print(f"Starting server on port {server_port}")
server.run() server.run()
@@ -290,11 +282,7 @@ def run_server(server_port: int) -> None:
def run_everything_legacy_sse_http_server(server_port: int) -> None: def run_everything_legacy_sse_http_server(server_port: int) -> None:
"""Run the comprehensive server with all features.""" """Run the comprehensive server with all features."""
_, app = make_everything_fastmcp_app() _, app = make_everything_fastmcp_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting comprehensive server on port {server_port}") print(f"Starting comprehensive server on port {server_port}")
server.run() server.run()
@@ -302,11 +290,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None:
def run_streamable_http_server(server_port: int) -> None: def run_streamable_http_server(server_port: int) -> None:
"""Run the StreamableHTTP server.""" """Run the StreamableHTTP server."""
_, app = make_fastmcp_streamable_http_app() _, app = make_fastmcp_streamable_http_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting StreamableHTTP server on port {server_port}") print(f"Starting StreamableHTTP server on port {server_port}")
server.run() server.run()
@@ -314,11 +298,7 @@ def run_streamable_http_server(server_port: int) -> None:
def run_everything_server(server_port: int) -> None: def run_everything_server(server_port: int) -> None:
"""Run the comprehensive StreamableHTTP server with all features.""" """Run the comprehensive StreamableHTTP server with all features."""
_, app = make_everything_fastmcp_streamable_http_app() _, app = make_everything_fastmcp_streamable_http_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting comprehensive StreamableHTTP server on port {server_port}") print(f"Starting comprehensive StreamableHTTP server on port {server_port}")
server.run() server.run()
@@ -326,11 +306,7 @@ def run_everything_server(server_port: int) -> None:
def run_stateless_http_server(server_port: int) -> None: def run_stateless_http_server(server_port: int) -> None:
"""Run the stateless StreamableHTTP server.""" """Run the stateless StreamableHTTP server."""
_, app = make_fastmcp_stateless_http_app() _, app = make_fastmcp_stateless_http_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"Starting stateless StreamableHTTP server on port {server_port}") print(f"Starting stateless StreamableHTTP server on port {server_port}")
server.run() server.run()
@@ -369,9 +345,7 @@ def server(server_port: int) -> Generator[None, None, None]:
@pytest.fixture() @pytest.fixture()
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
"""Start the StreamableHTTP server in a separate process.""" """Start the StreamableHTTP server in a separate process."""
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True)
target=run_streamable_http_server, args=(http_server_port,), daemon=True
)
print("Starting StreamableHTTP server process") print("Starting StreamableHTTP server process")
proc.start() proc.start()
@@ -388,9 +362,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None]
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts")
f"StreamableHTTP server failed to start after {max_attempts} attempts"
)
yield yield
@@ -427,9 +399,7 @@ def stateless_http_server(
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts")
f"Stateless server failed to start after {max_attempts} attempts"
)
yield yield
@@ -459,9 +429,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
@pytest.mark.anyio @pytest.mark.anyio
async def test_fastmcp_streamable_http( async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None:
streamable_http_server: None, http_server_url: str
) -> None:
"""Test that FastMCP works with StreamableHTTP transport.""" """Test that FastMCP works with StreamableHTTP transport."""
# Connect to the server using StreamableHTTP # Connect to the server using StreamableHTTP
async with streamablehttp_client(http_server_url + "/mcp") as ( async with streamablehttp_client(http_server_url + "/mcp") as (
@@ -484,9 +452,7 @@ async def test_fastmcp_streamable_http(
@pytest.mark.anyio @pytest.mark.anyio
async def test_fastmcp_stateless_streamable_http( async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None:
stateless_http_server: None, stateless_http_server_url: str
) -> None:
"""Test that FastMCP works with stateless StreamableHTTP transport.""" """Test that FastMCP works with stateless StreamableHTTP transport."""
# Connect to the server using StreamableHTTP # Connect to the server using StreamableHTTP
async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
@@ -562,9 +528,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts")
f"Comprehensive server failed to start after {max_attempts} attempts"
)
yield yield
@@ -601,10 +565,7 @@ def everything_streamable_http_server(
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts")
f"Comprehensive StreamableHTTP server failed to start after "
f"{max_attempts} attempts"
)
yield yield
@@ -648,9 +609,7 @@ class NotificationCollector:
await self.handle_tool_list_changed(message.root.params) await self.handle_tool_list_changed(message.root.params)
async def call_all_mcp_features( async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None:
session: ClientSession, collector: NotificationCollector
) -> None:
""" """
Test all MCP features using the provided session. Test all MCP features using the provided session.
@@ -680,9 +639,7 @@ async def call_all_mcp_features(
# Test progress callback functionality # Test progress callback functionality
progress_updates = [] progress_updates = []
async def progress_callback( async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
progress: float, total: float | None, message: str | None
) -> None:
"""Collect progress updates for testing (async version).""" """Collect progress updates for testing (async version)."""
progress_updates.append((progress, total, message)) progress_updates.append((progress, total, message))
print(f"Progress: {progress}/{total} - {message}") print(f"Progress: {progress}/{total} - {message}")
@@ -726,19 +683,12 @@ async def call_all_mcp_features(
# Verify we received log messages from the sampling tool # Verify we received log messages from the sampling tool
assert len(collector.log_messages) > 0 assert len(collector.log_messages) > 0
assert any( assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages)
"Requesting sampling for prompt" in msg.data for msg in collector.log_messages assert any("Received sampling result from model" in msg.data for msg in collector.log_messages)
)
assert any(
"Received sampling result from model" in msg.data
for msg in collector.log_messages
)
# 4. Test notification tool # 4. Test notification tool
notification_message = "test_notifications" notification_message = "test_notifications"
notification_result = await session.call_tool( notification_result = await session.call_tool("notification_tool", {"message": notification_message})
"notification_tool", {"message": notification_message}
)
assert len(notification_result.content) == 1 assert len(notification_result.content) == 1
assert isinstance(notification_result.content[0], TextContent) assert isinstance(notification_result.content[0], TextContent)
assert "Sent notifications and logs" in notification_result.content[0].text assert "Sent notifications and logs" in notification_result.content[0].text
@@ -773,36 +723,24 @@ async def call_all_mcp_features(
# 2. Dynamic resource # 2. Dynamic resource
resource_category = "test" resource_category = "test"
dynamic_content = await session.read_resource( dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}"))
AnyUrl(f"resource://dynamic/{resource_category}")
)
assert isinstance(dynamic_content, ReadResourceResult) assert isinstance(dynamic_content, ReadResourceResult)
assert len(dynamic_content.contents) == 1 assert len(dynamic_content.contents) == 1
assert isinstance(dynamic_content.contents[0], TextResourceContents) assert isinstance(dynamic_content.contents[0], TextResourceContents)
assert ( assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text
f"Dynamic resource content for category: {resource_category}"
in dynamic_content.contents[0].text
)
# 3. Template resource # 3. Template resource
resource_id = "456" resource_id = "456"
template_content = await session.read_resource( template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data"))
AnyUrl(f"resource://template/{resource_id}/data")
)
assert isinstance(template_content, ReadResourceResult) assert isinstance(template_content, ReadResourceResult)
assert len(template_content.contents) == 1 assert len(template_content.contents) == 1
assert isinstance(template_content.contents[0], TextResourceContents) assert isinstance(template_content.contents[0], TextResourceContents)
assert ( assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text
f"Template resource data for ID: {resource_id}"
in template_content.contents[0].text
)
# Test prompts # Test prompts
# 1. Simple prompt # 1. Simple prompt
prompts = await session.list_prompts() prompts = await session.list_prompts()
simple_prompt = next( simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None)
(p for p in prompts.prompts if p.name == "simple_prompt"), None
)
assert simple_prompt is not None assert simple_prompt is not None
prompt_topic = "AI" prompt_topic = "AI"
@@ -812,16 +750,12 @@ async def call_all_mcp_features(
# The actual message structure depends on the prompt implementation # The actual message structure depends on the prompt implementation
# 2. Complex prompt # 2. Complex prompt
complex_prompt = next( complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None)
(p for p in prompts.prompts if p.name == "complex_prompt"), None
)
assert complex_prompt is not None assert complex_prompt is not None
query = "What is AI?" query = "What is AI?"
context = "technical" context = "technical"
complex_result = await session.get_prompt( complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context})
"complex_prompt", {"user_query": query, "context": context}
)
assert isinstance(complex_result, GetPromptResult) assert isinstance(complex_result, GetPromptResult)
assert len(complex_result.messages) >= 1 assert len(complex_result.messages) >= 1
@@ -837,9 +771,7 @@ async def call_all_mcp_features(
print(f"Received headers: {headers_data}") print(f"Received headers: {headers_data}")
# Test 6: Call tool that returns full context # Test 6: Call tool that returns full context
context_result = await session.call_tool( context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"})
"echo_context", {"custom_request_id": "test-123"}
)
assert len(context_result.content) == 1 assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent) assert isinstance(context_result.content[0], TextContent)
@@ -871,9 +803,7 @@ async def sampling_callback(
@pytest.mark.anyio @pytest.mark.anyio
async def test_fastmcp_all_features_sse( async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None:
everything_server: None, everything_server_url: str
) -> None:
"""Test all MCP features work correctly with SSE transport.""" """Test all MCP features work correctly with SSE transport."""
# Create notification collector # Create notification collector

View File

@@ -59,9 +59,7 @@ class TestServer:
"""Test SSE app creation with different mount paths.""" """Test SSE app creation with different mount paths."""
# Test with default mount path # Test with default mount path
mcp = FastMCP() mcp = FastMCP()
with patch.object( with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize:
mcp, "_normalize_path", return_value="/messages/"
) as mock_normalize:
mcp.sse_app() mcp.sse_app()
# Verify _normalize_path was called with correct args # Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/", "/messages/") mock_normalize.assert_called_once_with("/", "/messages/")
@@ -69,18 +67,14 @@ class TestServer:
# Test with custom mount path in settings # Test with custom mount path in settings
mcp = FastMCP() mcp = FastMCP()
mcp.settings.mount_path = "/custom" mcp.settings.mount_path = "/custom"
with patch.object( with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize:
mcp, "_normalize_path", return_value="/custom/messages/"
) as mock_normalize:
mcp.sse_app() mcp.sse_app()
# Verify _normalize_path was called with correct args # Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/custom", "/messages/") mock_normalize.assert_called_once_with("/custom", "/messages/")
# Test with mount_path parameter # Test with mount_path parameter
mcp = FastMCP() mcp = FastMCP()
with patch.object( with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize:
mcp, "_normalize_path", return_value="/param/messages/"
) as mock_normalize:
mcp.sse_app(mount_path="/param") mcp.sse_app(mount_path="/param")
# Verify _normalize_path was called with correct args # Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/param", "/messages/") mock_normalize.assert_called_once_with("/param", "/messages/")
@@ -103,9 +97,7 @@ class TestServer:
# Verify path values # Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse" assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert ( assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
mount_routes[0].path == "/messages"
), "Mount route path should be /messages"
# Test with mount path as parameter # Test with mount path as parameter
mcp = FastMCP() mcp = FastMCP()
@@ -121,20 +113,14 @@ class TestServer:
# Verify path values # Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse" assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert ( assert mount_routes[0].path == "/messages", "Mount route path should be /messages"
mount_routes[0].path == "/messages"
), "Mount route path should be /messages"
@pytest.mark.anyio @pytest.mark.anyio
async def test_non_ascii_description(self): async def test_non_ascii_description(self):
"""Test that FastMCP handles non-ASCII characters in descriptions correctly""" """Test that FastMCP handles non-ASCII characters in descriptions correctly"""
mcp = FastMCP() mcp = FastMCP()
@mcp.tool( @mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"))
description=(
"🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"
)
)
def hello_world(name: str = "世界") -> str: def hello_world(name: str = "世界") -> str:
return f"¡Hola, {name}! 👋" return f"¡Hola, {name}! 👋"
@@ -187,9 +173,7 @@ class TestServer:
async def test_add_resource_decorator_incorrect_usage(self): async def test_add_resource_decorator_incorrect_usage(self):
mcp = FastMCP() mcp = FastMCP()
with pytest.raises( with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"):
TypeError, match="The @resource decorator was used incorrectly"
):
@mcp.resource # Missing parentheses #type: ignore @mcp.resource # Missing parentheses #type: ignore
def get_data(x: str) -> str: def get_data(x: str) -> str:
@@ -373,9 +357,7 @@ class TestServerResources:
def get_text(): def get_text():
return "Hello, world!" return "Hello, world!"
resource = FunctionResource( resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text)
uri=AnyUrl("resource://test"), name="test", fn=get_text
)
mcp.add_resource(resource) mcp.add_resource(resource)
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
@@ -411,9 +393,7 @@ class TestServerResources:
text_file = tmp_path / "test.txt" text_file = tmp_path / "test.txt"
text_file.write_text("Hello from file!") text_file.write_text("Hello from file!")
resource = FileResource( resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file)
uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file
)
mcp.add_resource(resource) mcp.add_resource(resource)
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
@@ -440,10 +420,7 @@ class TestServerResources:
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
result = await client.read_resource(AnyUrl("file://test.bin")) result = await client.read_resource(AnyUrl("file://test.bin"))
assert isinstance(result.contents[0], BlobResourceContents) assert isinstance(result.contents[0], BlobResourceContents)
assert ( assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode()
result.contents[0].blob
== base64.b64encode(b"Binary file data").decode()
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_function_resource(self): async def test_function_resource(self):
@@ -532,9 +509,7 @@ class TestServerResourceTemplates:
return f"Data for {org}/{repo}" return f"Data for {org}/{repo}"
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
result = await client.read_resource( result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data"))
AnyUrl("resource://cursor/fastmcp/data")
)
assert isinstance(result.contents[0], TextResourceContents) assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Data for cursor/fastmcp" assert result.contents[0].text == "Data for cursor/fastmcp"

View File

@@ -147,9 +147,7 @@ class TestAddTools:
def test_add_lambda_with_no_name(self): def test_add_lambda_with_no_name(self):
manager = ToolManager() manager = ToolManager()
with pytest.raises( with pytest.raises(ValueError, match="You must provide a name for lambda functions"):
ValueError, match="You must provide a name for lambda functions"
):
manager.add_tool(lambda x: x) manager.add_tool(lambda x: x)
def test_warn_on_duplicate_tools(self, caplog): def test_warn_on_duplicate_tools(self, caplog):
@@ -346,9 +344,7 @@ class TestContextHandling:
tool = manager.add_tool(tool_without_context) tool = manager.add_tool(tool_without_context)
assert tool.context_kwarg is None assert tool.context_kwarg is None
def tool_with_parametrized_context( def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str:
x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]
) -> str:
return str(x) return str(x)
tool = manager.add_tool(tool_with_parametrized_context) tool = manager.add_tool(tool_with_parametrized_context)

View File

@@ -10,13 +10,7 @@ from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations
ClientResult,
ServerNotification,
ServerRequest,
Tool,
ToolAnnotations,
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -45,18 +39,12 @@ async def test_lowlevel_server_tool_annotations():
) )
] ]
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
SessionMessage client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)
](10)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](10)
# Message handler for client # Message handler for client
async def message_handler( async def message_handler(
message: RequestResponder[ServerRequest, ClientResult] message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
| ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message

View File

@@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path):
@server.read_resource() @server.read_resource()
async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]:
return [ return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")]
ReadResourceContents(
content=b"Hello World", mime_type="application/octet-stream"
)
]
# Get the handler directly from the server # Get the handler directly from the server
handler = server.request_handlers[types.ReadResourceRequest] handler = server.request_handlers[types.ReadResourceRequest]

View File

@@ -20,18 +20,12 @@ from mcp.types import (
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_session_initialize(): async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
# Create a message handler to catch exceptions # Create a message handler to catch exceptions
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
@@ -54,9 +48,7 @@ async def test_server_session_initialize():
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
if isinstance(message, ClientNotification) and isinstance( if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification):
message.root, InitializedNotification
):
received_initialized = True received_initialized = True
return return
@@ -111,12 +103,8 @@ async def test_server_capabilities():
@pytest.mark.anyio @pytest.mark.anyio
async def test_server_session_initialize_with_older_protocol_version(): async def test_server_session_initialize_with_older_protocol_version():
"""Test that server accepts and responds with older protocol (2024-11-05).""" """Test that server accepts and responds with older protocol (2024-11-05)."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
SessionMessage client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage | Exception
](1)
received_initialized = False received_initialized = False
received_protocol_version = None received_protocol_version = None
@@ -137,9 +125,7 @@ async def test_server_session_initialize_with_older_protocol_version():
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
if isinstance(message, types.ClientNotification) and isinstance( if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification):
message.root, InitializedNotification
):
received_initialized = True received_initialized = True
return return
@@ -157,9 +143,7 @@ async def test_server_session_initialize_with_older_protocol_version():
params=types.InitializeRequestParams( params=types.InitializeRequestParams(
protocolVersion="2024-11-05", protocolVersion="2024-11-05",
capabilities=types.ClientCapabilities(), capabilities=types.ClientCapabilities(),
clientInfo=types.Implementation( clientInfo=types.Implementation(name="test-client", version="1.0.0"),
name="test-client", version="1.0.0"
),
).model_dump(by_alias=True, mode="json", exclude_none=True), ).model_dump(by_alias=True, mode="json", exclude_none=True),
) )
) )

View File

@@ -22,9 +22,10 @@ async def test_stdio_server():
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
stdin.seek(0) stdin.seek(0)
async with stdio_server( async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout) read_stream,
) as (read_stream, write_stream): write_stream,
):
received_messages = [] received_messages = []
async with read_stream: async with read_stream:
async for message in read_stream: async for message in read_stream:
@@ -36,12 +37,8 @@ async def test_stdio_server():
# Verify received messages # Verify received messages
assert len(received_messages) == 2 assert len(received_messages) == 2
assert received_messages[0] == JSONRPCMessage( assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
)
assert received_messages[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
)
# Test sending responses from the server # Test sending responses from the server
responses = [ responses = [
@@ -58,13 +55,7 @@ async def test_stdio_server():
output_lines = stdout.readlines() output_lines = stdout.readlines()
assert len(output_lines) == 2 assert len(output_lines) == 2
received_responses = [ received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines]
JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines
]
assert len(received_responses) == 2 assert len(received_responses) == 2
assert received_responses[0] == JSONRPCMessage( assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"))
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}))
)
assert received_responses[1] == JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
)

View File

@@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once():
async with manager.run(): async with manager.run():
pass pass
assert ( assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value)
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(excinfo.value)
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -51,10 +48,7 @@ async def test_run_prevents_concurrent_calls():
# One should succeed, one should fail # One should succeed, one should fail
assert len(errors) == 1 assert len(errors) == 1
assert ( assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])
"StreamableHTTPSessionManager .run() can only be called once per instance"
in str(errors[0])
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -76,6 +70,4 @@ async def test_handle_request_without_run_raises_error():
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
await manager.handle_request(scope, receive, send) await manager.handle_request(scope, receive, send)
assert "Task group is not initialized. Make sure to use run()." in str( assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)
excinfo.value
)

View File

@@ -22,12 +22,8 @@ from mcp.shared.session import (
async def test_bidirectional_progress_notifications(): async def test_bidirectional_progress_notifications():
"""Test that both client and server can send progress notifications.""" """Test that both client and server can send progress notifications."""
# Create memory streams for client/server # Create memory streams for client/server
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
SessionMessage client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
# Run a server session so we can send progress updates in tool # Run a server session so we can send progress updates in tool
async def run_server(): async def run_server():
@@ -134,9 +130,7 @@ async def test_bidirectional_progress_notifications():
# Client message handler to store progress notifications # Client message handler to store progress notifications
async def handle_client_message( async def handle_client_message(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message
@@ -172,9 +166,7 @@ async def test_bidirectional_progress_notifications():
await client_session.list_tools() await client_session.list_tools()
# Call test_tool with progress token # Call test_tool with progress token
await client_session.call_tool( await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}})
"test_tool", {"_meta": {"progressToken": client_progress_token}}
)
# Send progress notifications from client to server # Send progress notifications from client to server
await client_session.send_progress_notification( await client_session.send_progress_notification(
@@ -221,12 +213,8 @@ async def test_bidirectional_progress_notifications():
async def test_progress_context_manager(): async def test_progress_context_manager():
"""Test client using progress context manager for sending progress notifications.""" """Test client using progress context manager for sending progress notifications."""
# Create memory streams for client/server # Create memory streams for client/server
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
SessionMessage client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)
](5)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](5)
# Track progress updates # Track progress updates
server_progress_updates = [] server_progress_updates = []
@@ -270,9 +258,7 @@ async def test_progress_context_manager():
# Client message handler # Client message handler
async def handle_client_message( async def handle_client_message(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, Exception): if isinstance(message, Exception):
raise message raise message

View File

@@ -90,9 +90,7 @@ async def test_request_cancellation():
ClientRequest( ClientRequest(
types.CallToolRequest( types.CallToolRequest(
method="tools/call", method="tools/call",
params=types.CallToolRequestParams( params=types.CallToolRequestParams(name="slow_tool", arguments={}),
name="slow_tool", arguments={}
),
) )
), ),
types.CallToolResult, types.CallToolResult,
@@ -103,9 +101,7 @@ async def test_request_cancellation():
assert "Request cancelled" in str(e) assert "Request cancelled" in str(e)
ev_cancelled.set() ev_cancelled.set()
async with create_connected_server_and_client_session( async with create_connected_server_and_client_session(make_server()) as client_session:
make_server()
) as client_session:
async with anyio.create_task_group() as tg: async with anyio.create_task_group() as tg:
tg.start_soon(make_request, client_session) tg.start_soon(make_request, client_session)

View File

@@ -60,11 +60,7 @@ class ServerTest(Server):
await anyio.sleep(2.0) await anyio.sleep(2.0)
return f"Slow response from {uri.host}" return f"Slow response from {uri.host}"
raise McpError( raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
@self.list_tools() @self.list_tools()
async def handle_list_tools() -> list[Tool]: async def handle_list_tools() -> list[Tool]:
@@ -88,12 +84,8 @@ def make_server_app() -> Starlette:
server = ServerTest() server = ServerTest()
async def handle_sse(request: Request) -> Response: async def handle_sse(request: Request) -> Response:
async with sse.connect_sse( async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
request.scope, request.receive, request._send await server.run(streams[0], streams[1], server.create_initialization_options())
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
return Response() return Response()
app = Starlette( app = Starlette(
@@ -108,11 +100,7 @@ def make_server_app() -> Starlette:
def run_server(server_port: int) -> None: def run_server(server_port: int) -> None:
app = make_server_app() app = make_server_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}") print(f"starting server on {server_port}")
server.run() server.run()
@@ -124,9 +112,7 @@ def run_server(server_port: int) -> None:
@pytest.fixture() @pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]: def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
target=run_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process") print("starting process")
proc.start() proc.start()
@@ -171,10 +157,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
async def connection_test() -> None: async def connection_test() -> None:
async with http_client.stream("GET", "/sse") as response: async with http_client.stream("GET", "/sse") as response:
assert response.status_code == 200 assert response.status_code == 200
assert ( assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
response.headers["content-type"]
== "text/event-stream; charset=utf-8"
)
line_number = 0 line_number = 0
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@@ -206,9 +189,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
@pytest.fixture @pytest.fixture
async def initialized_sse_client_session( async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
await session.initialize() await session.initialize()
@@ -236,9 +217,7 @@ async def test_sse_client_exception_handling(
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.skip( @pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
"this test highlights a possible bug in SSE read timeout exception handling"
)
async def test_sse_client_timeout( async def test_sse_client_timeout(
initialized_sse_client_session: ClientSession, initialized_sse_client_session: ClientSession,
) -> None: ) -> None:
@@ -260,11 +239,7 @@ async def test_sse_client_timeout(
def run_mounted_server(server_port: int) -> None: def run_mounted_server(server_port: int) -> None:
app = make_server_app() app = make_server_app()
main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}") print(f"starting server on {server_port}")
server.run() server.run()
@@ -276,9 +251,7 @@ def run_mounted_server(server_port: int) -> None:
@pytest.fixture() @pytest.fixture()
def mounted_server(server_port: int) -> Generator[None, None, None]: def mounted_server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process") print("starting process")
proc.start() proc.start()
@@ -308,9 +281,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio @pytest.mark.anyio
async def test_sse_client_basic_connection_mounted_app( async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
mounted_server: None, server_url: str
) -> None:
async with sse_client(server_url + "/mounted_app/sse") as streams: async with sse_client(server_url + "/mounted_app/sse") as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
# Test initialization # Test initialization
@@ -372,12 +343,8 @@ def run_context_server(server_port: int) -> None:
context_server = RequestContextServer() context_server = RequestContextServer()
async def handle_sse(request: Request) -> Response: async def handle_sse(request: Request) -> Response:
async with sse.connect_sse( async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
request.scope, request.receive, request._send await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
) as streams:
await context_server.run(
streams[0], streams[1], context_server.create_initialization_options()
)
return Response() return Response()
app = Starlette( app = Starlette(
@@ -387,11 +354,7 @@ def run_context_server(server_port: int) -> None:
] ]
) )
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting context server on {server_port}") print(f"starting context server on {server_port}")
server.run() server.run()
@@ -399,9 +362,7 @@ def run_context_server(server_port: int) -> None:
@pytest.fixture() @pytest.fixture()
def context_server(server_port: int) -> Generator[None, None, None]: def context_server(server_port: int) -> Generator[None, None, None]:
"""Fixture that provides a server with request context capture""" """Fixture that provides a server with request context capture"""
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
target=run_context_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting context server process") print("starting context server process")
proc.start() proc.start()
@@ -418,9 +379,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")
f"Context server failed to start after {max_attempts} attempts"
)
yield yield
@@ -432,9 +391,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio @pytest.mark.anyio
async def test_request_context_propagation( async def test_request_context_propagation(context_server: None, server_url: str) -> None:
context_server: None, server_url: str
) -> None:
"""Test that request context is properly propagated through SSE transport.""" """Test that request context is properly propagated through SSE transport."""
# Test with custom headers # Test with custom headers
custom_headers = { custom_headers = {
@@ -458,11 +415,7 @@ async def test_request_context_propagation(
# Parse the JSON response # Parse the JSON response
assert len(tool_result.content) == 1 assert len(tool_result.content) == 1
headers_data = json.loads( headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
tool_result.content[0].text
if tool_result.content[0].type == "text"
else "{}"
)
# Verify headers were propagated # Verify headers were propagated
assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("authorization") == "Bearer test-token"
@@ -487,15 +440,11 @@ async def test_request_context_isolation(context_server: None, server_url: str)
await session.initialize() await session.initialize()
# Call the tool that echoes context # Call the tool that echoes context
tool_result = await session.call_tool( tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1 assert len(tool_result.content) == 1
context_data = json.loads( context_data = json.loads(
tool_result.content[0].text tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
if tool_result.content[0].type == "text"
else "{}"
) )
contexts.append(context_data) contexts.append(context_data)
@@ -514,8 +463,4 @@ def test_sse_message_id_coercion():
""" """
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
msg = types.JSONRPCMessage.model_validate_json(json_message) msg = types.JSONRPCMessage.model_validate_json(json_message)
assert msg == snapshot( assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
types.JSONRPCMessage(
root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
)
)

View File

@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
self._event_id_counter = 0 self._event_id_counter = 0
async def store_event( async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
self, stream_id: StreamId, message: types.JSONRPCMessage
) -> EventId:
"""Store an event and return its ID.""" """Store an event and return its ID."""
self._event_id_counter += 1 self._event_id_counter += 1
event_id = str(self._event_id_counter) event_id = str(self._event_id_counter)
@@ -156,9 +154,7 @@ class ServerTest(Server):
# When the tool is called, send a notification to test GET stream # When the tool is called, send a notification to test GET stream
if name == "test_tool_with_standalone_notification": if name == "test_tool_with_standalone_notification":
await ctx.session.send_resource_updated( await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
uri=AnyUrl("http://test_resource")
)
return [TextContent(type="text", text=f"Called {name}")] return [TextContent(type="text", text=f"Called {name}")]
elif name == "long_running_with_checkpoints": elif name == "long_running_with_checkpoints":
@@ -189,9 +185,7 @@ class ServerTest(Server):
messages=[ messages=[
types.SamplingMessage( types.SamplingMessage(
role="user", role="user",
content=types.TextContent( content=types.TextContent(type="text", text="Server needs client sampling"),
type="text", text="Server needs client sampling"
),
) )
], ],
max_tokens=100, max_tokens=100,
@@ -199,11 +193,7 @@ class ServerTest(Server):
) )
# Return the sampling result in the tool response # Return the sampling result in the tool response
response = ( response = sampling_result.content.text if sampling_result.content.type == "text" else None
sampling_result.content.text
if sampling_result.content.type == "text"
else None
)
return [ return [
TextContent( TextContent(
type="text", type="text",
@@ -214,9 +204,7 @@ class ServerTest(Server):
return [TextContent(type="text", text=f"Called {name}")] return [TextContent(type="text", text=f"Called {name}")]
def create_app( def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
is_json_response_enabled=False, event_store: EventStore | None = None
) -> Starlette:
"""Create a Starlette application for testing using the session manager. """Create a Starlette application for testing using the session manager.
Args: Args:
@@ -245,9 +233,7 @@ def create_app(
return app return app
def run_server( def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
) -> None:
"""Run the test server. """Run the test server.
Args: Args:
@@ -300,9 +286,7 @@ def json_server_port() -> int:
@pytest.fixture @pytest.fixture
def basic_server(basic_server_port: int) -> Generator[None, None, None]: def basic_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start a basic server.""" """Start a basic server."""
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
target=run_server, kwargs={"port": basic_server_port}, daemon=True
)
proc.start() proc.start()
# Wait for server to be running # Wait for server to be running
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_client_resource_read(initialized_client_session): async def test_streamablehttp_client_resource_read(initialized_client_session):
"""Test client resource read functionality.""" """Test client resource read functionality."""
response = await initialized_client_session.read_resource( response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
uri=AnyUrl("foobar://test-resource")
)
assert len(response.contents) == 1 assert len(response.contents) == 1
assert response.contents[0].uri == AnyUrl("foobar://test-resource") assert response.contents[0].uri == AnyUrl("foobar://test-resource")
assert response.contents[0].text == "Read test-resource" assert response.contents[0].text == "Read test-resource"
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
async def test_streamablehttp_client_error_handling(initialized_client_session): async def test_streamablehttp_client_error_handling(initialized_client_session):
"""Test error handling in client.""" """Test error handling in client."""
with pytest.raises(McpError) as exc_info: with pytest.raises(McpError) as exc_info:
await initialized_client_session.read_resource( await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
uri=AnyUrl("unknown://test-error")
)
assert exc_info.value.error.code == 0 assert exc_info.value.error.code == 0
assert "Unknown resource: unknown://test-error" in exc_info.value.error.message assert "Unknown resource: unknown://test-error" in exc_info.value.error.message
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_client_session_persistence( async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
basic_server, basic_server_url
):
"""Test that session ID persists across requests.""" """Test that session ID persists across requests."""
async with streamablehttp_client(f"{basic_server_url}/mcp") as ( async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream, read_stream,
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_client_json_response( async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
json_response_server, json_server_url
):
"""Test client with JSON response mode.""" """Test client with JSON response mode."""
async with streamablehttp_client(f"{json_server_url}/mcp") as ( async with streamablehttp_client(f"{json_server_url}/mcp") as (
read_stream, read_stream,
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
# Define message handler to capture notifications # Define message handler to capture notifications
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, types.ServerNotification): if isinstance(message, types.ServerNotification):
notifications_received.append(message) notifications_received.append(message)
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
write_stream, write_stream,
_, _,
): ):
async with ClientSession( async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
read_stream, write_stream, message_handler=message_handler
) as session:
# Initialize the session - this triggers the GET stream setup # Initialize the session - this triggers the GET stream setup
result = await session.initialize() result = await session.initialize()
assert isinstance(result, InitializeResult) assert isinstance(result, InitializeResult)
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
assert str(notif.root.params.uri) == "http://test_resource/" assert str(notif.root.params.uri) == "http://test_resource/"
resource_update_found = True resource_update_found = True
assert ( assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"
resource_update_found
), "ResourceUpdatedNotification not received via GET stream"
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_client_session_termination( async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
basic_server, basic_server_url
):
"""Test client session termination functionality.""" """Test client session termination functionality."""
captured_session_id = None captured_session_id = None
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_client_session_termination_204( async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
basic_server, basic_server_url, monkeypatch
):
"""Test client session termination functionality with a 204 response. """Test client session termination functionality with a 204 response.
This test patches the httpx client to return a 204 response for DELETEs. This test patches the httpx client to return a 204 response for DELETEs.
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
tool_started = False tool_started = False
async def message_handler( async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
| types.ServerNotification
| Exception,
) -> None: ) -> None:
if isinstance(message, types.ServerNotification): if isinstance(message, types.ServerNotification):
captured_notifications.append(message) captured_notifications.append(message)
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream, write_stream,
get_session_id, get_session_id,
): ):
async with ClientSession( async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
read_stream, write_stream, message_handler=message_handler
) as session:
# Initialize the session # Initialize the session
result = await session.initialize() result = await session.initialize()
assert isinstance(result, InitializeResult) assert isinstance(result, InitializeResult)
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest( types.ClientRequest(
types.CallToolRequest( types.CallToolRequest(
method="tools/call", method="tools/call",
params=types.CallToolRequestParams( params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
name="long_running_with_checkpoints", arguments={}
),
) )
), ),
types.CallToolResult, types.CallToolResult,
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
write_stream, write_stream,
_, _,
): ):
async with ClientSession( async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
read_stream, write_stream, message_handler=message_handler
) as session:
# Don't initialize - just use the existing session # Don't initialize - just use the existing session
# Resume the tool with the resumption token # Resume the tool with the resumption token
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
types.ClientRequest( types.ClientRequest(
types.CallToolRequest( types.CallToolRequest(
method="tools/call", method="tools/call",
params=types.CallToolRequestParams( params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
name="long_running_with_checkpoints", arguments={}
),
) )
), ),
types.CallToolResult, types.CallToolResult,
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
# Should not have the first notification # Should not have the first notification
# Check that "Tool started" notification isn't repeated when resuming # Check that "Tool started" notification isn't repeated when resuming
assert not any( assert not any(
isinstance(n.root, types.LoggingMessageNotification) isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
and n.root.params.data == "Tool started"
for n in captured_notifications for n in captured_notifications
) )
# there is no intersection between pre and post notifications # there is no intersection between pre and post notifications
assert not any( assert not any(n in captured_notifications_pre for n in captured_notifications)
n in captured_notifications_pre for n in captured_notifications
)
@pytest.mark.anyio @pytest.mark.anyio
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
nonlocal sampling_callback_invoked, captured_message_params nonlocal sampling_callback_invoked, captured_message_params
sampling_callback_invoked = True sampling_callback_invoked = True
captured_message_params = params captured_message_params = params
message_received = ( message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
params.messages[0].content.text
if params.messages[0].content.type == "text"
else None
)
return types.CreateMessageResult( return types.CreateMessageResult(
role="assistant", role="assistant",
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
# Verify the tool result contains the expected content # Verify the tool result contains the expected content
assert len(tool_result.content) == 1 assert len(tool_result.content) == 1
assert tool_result.content[0].type == "text" assert tool_result.content[0].type == "text"
assert ( assert "Response from sampling: Received message from server" in tool_result.content[0].text
"Response from sampling: Received message from server"
in tool_result.content[0].text
)
# Verify sampling callback was invoked # Verify sampling callback was invoked
assert sampling_callback_invoked assert sampling_callback_invoked
assert captured_message_params is not None assert captured_message_params is not None
assert len(captured_message_params.messages) == 1 assert len(captured_message_params.messages) == 1
assert ( assert captured_message_params.messages[0].content.text == "Server needs client sampling"
captured_message_params.messages[0].content.text
== "Server needs client sampling"
)
# Context-aware server implementation for testing request context propagation # Context-aware server implementation for testing request context propagation
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
@pytest.fixture @pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process.""" """Start the context-aware server in a separate process."""
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
target=run_context_aware_server, args=(basic_server_port,), daemon=True
)
proc.start() proc.start()
# Wait for server to be running # Wait for server to be running
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
time.sleep(0.1) time.sleep(0.1)
attempt += 1 attempt += 1
else: else:
raise RuntimeError( raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")
f"Context-aware server failed to start after {max_attempts} attempts"
)
yield yield
@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_request_context_propagation( async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
context_aware_server: None, basic_server_url: str
) -> None:
"""Test that request context is properly propagated through StreamableHTTP.""" """Test that request context is properly propagated through StreamableHTTP."""
custom_headers = { custom_headers = {
"Authorization": "Bearer test-token", "Authorization": "Bearer test-token",
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
"X-Trace-Id": "trace-123", "X-Trace-Id": "trace-123",
} }
async with streamablehttp_client( async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
f"{basic_server_url}/mcp", headers=custom_headers read_stream,
) as (read_stream, write_stream, _): write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session: async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize() result = await session.initialize()
assert isinstance(result, InitializeResult) assert isinstance(result, InitializeResult)
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(
@pytest.mark.anyio @pytest.mark.anyio
async def test_streamablehttp_request_context_isolation( async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
context_aware_server: None, basic_server_url: str
) -> None:
"""Test that request contexts are isolated between StreamableHTTP clients.""" """Test that request contexts are isolated between StreamableHTTP clients."""
contexts = [] contexts = []
@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
"Authorization": f"Bearer token-{i}", "Authorization": f"Bearer token-{i}",
} }
async with streamablehttp_client( async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
f"{basic_server_url}/mcp", headers=headers
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session: async with ClientSession(read_stream, write_stream) as session:
await session.initialize() await session.initialize()
# Call the tool that echoes context # Call the tool that echoes context
tool_result = await session.call_tool( tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
"echo_context", {"request_id": f"request-{i}"}
)
assert len(tool_result.content) == 1 assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent) assert isinstance(tool_result.content[0], TextContent)

View File

@@ -54,11 +54,7 @@ class ServerTest(Server):
await anyio.sleep(2.0) await anyio.sleep(2.0)
return f"Slow response from {uri.host}" return f"Slow response from {uri.host}"
raise McpError( raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))
error=ErrorData(
code=404, message="OOPS! no resource with that URI was found"
)
)
@self.list_tools() @self.list_tools()
async def handle_list_tools() -> list[Tool]: async def handle_list_tools() -> list[Tool]:
@@ -81,12 +77,8 @@ def make_server_app() -> Starlette:
server = ServerTest() server = ServerTest()
async def handle_ws(websocket): async def handle_ws(websocket):
async with websocket_server( async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
websocket.scope, websocket.receive, websocket.send await server.run(streams[0], streams[1], server.create_initialization_options())
) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()
)
app = Starlette( app = Starlette(
routes=[ routes=[
@@ -99,11 +91,7 @@ def make_server_app() -> Starlette:
def run_server(server_port: int) -> None: def run_server(server_port: int) -> None:
app = make_server_app() app = make_server_app()
server = uvicorn.Server( server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
config=uvicorn.Config(
app=app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}") print(f"starting server on {server_port}")
server.run() server.run()
@@ -115,9 +103,7 @@ def run_server(server_port: int) -> None:
@pytest.fixture() @pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]: def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process( proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
target=run_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process") print("starting process")
proc.start() proc.start()
@@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]:
@pytest.fixture() @pytest.fixture()
async def initialized_ws_client_session( async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
server, server_url: str
) -> AsyncGenerator[ClientSession, None]:
"""Create and initialize a WebSocket client session""" """Create and initialize a WebSocket client session"""
async with websocket_client(server_url + "/ws") as streams: async with websocket_client(server_url + "/ws") as streams:
async with ClientSession(*streams) as session: async with ClientSession(*streams) as session:
@@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response(
initialized_ws_client_session: ClientSession, initialized_ws_client_session: ClientSession,
) -> None: ) -> None:
"""Test a successful request and response via WebSocket""" """Test a successful request and response via WebSocket"""
result = await initialized_ws_client_session.read_resource( result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
AnyUrl("foobar://example")
)
assert isinstance(result, ReadResourceResult) assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list) assert isinstance(result.contents, list)
assert len(result.contents) > 0 assert len(result.contents) > 0
@@ -218,9 +200,7 @@ async def test_ws_client_timeout(
# Now test that we can still use the session after a timeout # Now test that we can still use the session after a timeout
with anyio.fail_after(5): # Longer timeout to allow completion with anyio.fail_after(5): # Longer timeout to allow completion
result = await initialized_ws_client_session.read_resource( result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
AnyUrl("foobar://example")
)
assert isinstance(result, ReadResourceResult) assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list) assert isinstance(result.contents, list)
assert len(result.contents) > 0 assert len(result.contents) > 0

View File

@@ -31,9 +31,7 @@ async def test_complex_inputs():
async with client_session(mcp._mcp_server) as client: async with client_session(mcp._mcp_server) as client:
tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]} tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]}
result = await client.call_tool( result = await client.call_tool("name_shrimp", {"tank": tank, "extra_names": ["charlie"]})
"name_shrimp", {"tank": tank, "extra_names": ["charlie"]}
)
assert len(result.content) == 3 assert len(result.content) == 3
assert isinstance(result.content[0], TextContent) assert isinstance(result.content[0], TextContent)
assert isinstance(result.content[1], TextContent) assert isinstance(result.content[1], TextContent)
@@ -86,9 +84,7 @@ async def test_desktop(monkeypatch):
def test_docs_examples(example: CodeExample, eval_example: EvalExample): def test_docs_examples(example: CodeExample, eval_example: EvalExample):
ruff_ignore: list[str] = ["F841", "I001"] ruff_ignore: list[str] = ["F841", "I001"]
eval_example.set_config( eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=88)
ruff_ignore=ruff_ignore, target_version="py310", line_length=88
)
if eval_example.update_examples: # pragma: no cover if eval_example.update_examples: # pragma: no cover
eval_example.format(example) eval_example.format(example)