Mirror of https://github.com/aljazceru/mcp-python-sdk.git (synced 2025-12-19 06:54:18 +01:00)

Use 120 characters instead of 88 (#856)

Commit 543961968c (parent f7265f7b91), committed via GitHub
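In practical terms, this commit raises the Ruff line-length limit in pyproject.toml from 88 to 120 columns and reflows code that had been wrapped for the old limit: multi-line call sites, signatures, conditionals, and implicit string concatenations are collapsed back onto single lines where they now fit. A minimal sketch of the configuration side of the change, mirroring the pyproject.toml hunk further down (the surrounding keys are unchanged):

```toml
# pyproject.toml (after this commit)
[tool.ruff]
line-length = 120   # previously 88
target-version = "py310"
```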
@@ -47,18 +47,14 @@ mcp = FastMCP(
 
 DB_DSN = "postgresql://postgres:postgres@localhost:54320/memory_db"
 # reset memory with rm ~/.fastmcp/{USER}/memory/*
-PROFILE_DIR = (
-    Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory"
-).resolve()
+PROFILE_DIR = (Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory").resolve()
 PROFILE_DIR.mkdir(parents=True, exist_ok=True)
 
 
 def cosine_similarity(a: list[float], b: list[float]) -> float:
     a_array = np.array(a, dtype=np.float64)
     b_array = np.array(b, dtype=np.float64)
-    return np.dot(a_array, b_array) / (
-        np.linalg.norm(a_array) * np.linalg.norm(b_array)
-    )
+    return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array))
 
 
 async def do_ai[T](

@@ -97,9 +93,7 @@ class MemoryNode(BaseModel):
     summary: str = ""
     importance: float = 1.0
     access_count: int = 0
-    timestamp: float = Field(
-        default_factory=lambda: datetime.now(timezone.utc).timestamp()
-    )
+    timestamp: float = Field(default_factory=lambda: datetime.now(timezone.utc).timestamp())
     embedding: list[float]
 
     @classmethod

@@ -152,9 +146,7 @@ class MemoryNode(BaseModel):
         self.importance += other.importance
         self.access_count += other.access_count
         self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)]
-        self.summary = await do_ai(
-            self.content, "Summarize the following text concisely.", str, deps
-        )
+        self.summary = await do_ai(self.content, "Summarize the following text concisely.", str, deps)
         await self.save(deps)
         # Delete the merged node from the database
         if other.id is not None:

@@ -221,9 +213,7 @@ async def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo
 
 async def update_importance(user_embedding: list[float], deps: Deps):
     async with deps.pool.acquire() as conn:
-        rows = await conn.fetch(
-            "SELECT id, importance, access_count, embedding FROM memories"
-        )
+        rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories")
         for row in rows:
             memory_embedding = row["embedding"]
             similarity = cosine_similarity(user_embedding, memory_embedding)

@@ -273,9 +263,7 @@ async def display_memory_tree(deps: Deps) -> str:
         )
         result = ""
         for row in rows:
-            effective_importance = row["importance"] * (
-                1 + math.log(row["access_count"] + 1)
-            )
+            effective_importance = row["importance"] * (1 + math.log(row["access_count"] + 1))
             summary = row["summary"] or row["content"]
             result += f"- {summary} (Importance: {effective_importance:.2f})\n"
         return result

@@ -283,15 +271,11 @@ async def display_memory_tree(deps: Deps) -> str:
 
 @mcp.tool()
 async def remember(
-    contents: list[str] = Field(
-        description="List of observations or memories to store"
-    ),
+    contents: list[str] = Field(description="List of observations or memories to store"),
 ):
     deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool())
     try:
-        return "\n".join(
-            await asyncio.gather(*[add_memory(content, deps) for content in contents])
-        )
+        return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents]))
     finally:
         await deps.pool.close()
 

@@ -305,9 +289,7 @@ async def read_profile() -> str:
 
 
 async def initialize_database():
-    pool = await asyncpg.create_pool(
-        "postgresql://postgres:postgres@localhost:54320/postgres"
-    )
+    pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres")
     try:
         async with pool.acquire() as conn:
             await conn.execute("""
@@ -28,15 +28,11 @@ from mcp.server.fastmcp import FastMCP
 
 
 class SurgeSettings(BaseSettings):
-    model_config: SettingsConfigDict = SettingsConfigDict(
-        env_prefix="SURGE_", env_file=".env"
-    )
+    model_config: SettingsConfigDict = SettingsConfigDict(env_prefix="SURGE_", env_file=".env")
 
     api_key: str
     account_id: str
-    my_phone_number: Annotated[
-        str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)
-    ]
+    my_phone_number: Annotated[str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)]
     my_first_name: str
     my_last_name: str
 
@@ -8,10 +8,7 @@ from mcp.server.fastmcp import FastMCP
 mcp = FastMCP()
 
 
-@mcp.tool(
-    description="🌟 A tool that uses various Unicode characters in its description: "
-    "á é í ó ú ñ 漢字 🎉"
-)
+@mcp.tool(description="🌟 A tool that uses various Unicode characters in its description: " "á é í ó ú ñ 漢字 🎉")
 def hello_unicode(name: str = "世界", greeting: str = "¡Hola") -> str:
     """
     A simple tool that demonstrates Unicode handling in:
@@ -82,9 +82,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
         """Register a new OAuth client."""
         self.clients[client_info.client_id] = client_info
 
-    async def authorize(
-        self, client: OAuthClientInformationFull, params: AuthorizationParams
-    ) -> str:
+    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
         """Generate an authorization URL for GitHub OAuth flow."""
         state = params.state or secrets.token_hex(16)
 

@@ -92,9 +90,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
         self.state_mapping[state] = {
             "redirect_uri": str(params.redirect_uri),
             "code_challenge": params.code_challenge,
-            "redirect_uri_provided_explicitly": str(
-                params.redirect_uri_provided_explicitly
-            ),
+            "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly),
             "client_id": client.client_id,
         }
 

@@ -117,9 +113,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
 
         redirect_uri = state_data["redirect_uri"]
         code_challenge = state_data["code_challenge"]
-        redirect_uri_provided_explicitly = (
-            state_data["redirect_uri_provided_explicitly"] == "True"
-        )
+        redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True"
         client_id = state_data["client_id"]
 
         # Exchange code for token with GitHub

@@ -200,8 +194,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
                 for token, data in self.tokens.items()
                 # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/
                 # which you get depends on your GH app setup.
-                if (token.startswith("ghu_") or token.startswith("gho_"))
-                and data.client_id == client.client_id
+                if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id
             ),
             None,
         )

@@ -232,9 +225,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
 
         return access_token
 
-    async def load_refresh_token(
-        self, client: OAuthClientInformationFull, refresh_token: str
-    ) -> RefreshToken | None:
+    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
         """Load a refresh token - not supported."""
         return None
 

@@ -247,9 +238,7 @@ class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider):
         """Exchange refresh token"""
         raise NotImplementedError("Not supported")
 
-    async def revoke_token(
-        self, token: str, token_type_hint: str | None = None
-    ) -> None:
+    async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
         """Revoke a token."""
         if token in self.tokens:
             del self.tokens[token]

@@ -335,9 +324,7 @@ def create_simple_mcp_server(settings: ServerSettings) -> FastMCP:
             )
 
             if response.status_code != 200:
-                raise ValueError(
-                    f"GitHub API error: {response.status_code} - {response.text}"
-                )
+                raise ValueError(f"GitHub API error: {response.status_code} - {response.text}")
 
             return response.json()
 

@@ -361,9 +348,7 @@ def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) ->
         # No hardcoded credentials - all from environment variables
         settings = ServerSettings(host=host, port=port)
     except ValueError as e:
-        logger.error(
-            "Failed to load settings. Make sure environment variables are set:"
-        )
+        logger.error("Failed to load settings. Make sure environment variables are set:")
        logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=<your-client-id>")
        logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=<your-client-secret>")
        logger.error(f"Error: {e}")
@@ -96,7 +96,7 @@ select = ["C4", "E", "F", "I", "PERF", "UP"]
 ignore = ["PERF203"]
 
 [tool.ruff]
-line-length = 88
+line-length = 120
 target-version = "py310"
 
 [tool.ruff.lint.per-file-ignores]
@@ -21,9 +21,7 @@ def get_claude_config_path() -> Path | None:
     elif sys.platform == "darwin":
         path = Path(Path.home(), "Library", "Application Support", "Claude")
     elif sys.platform.startswith("linux"):
-        path = Path(
-            os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude"
-        )
+        path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude")
     else:
         return None
 

@@ -37,8 +35,7 @@ def get_uv_path() -> str:
     uv_path = shutil.which("uv")
     if not uv_path:
         logger.error(
-            "uv executable not found in PATH, falling back to 'uv'. "
-            "Please ensure uv is installed and in your PATH"
+            "uv executable not found in PATH, falling back to 'uv'. " "Please ensure uv is installed and in your PATH"
         )
         return "uv"  # Fall back to just "uv" if not found
     return uv_path

@@ -94,10 +91,7 @@ def update_claude_config(
         config["mcpServers"] = {}
 
     # Always preserve existing env vars and merge with new ones
-    if (
-        server_name in config["mcpServers"]
-        and "env" in config["mcpServers"][server_name]
-    ):
+    if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]:
         existing_env = config["mcpServers"][server_name]["env"]
         if env_vars:
             # New vars take precedence over existing ones
@@ -45,9 +45,7 @@ def _get_npx_command():
         # Try both npx.cmd and npx.exe on Windows
         for cmd in ["npx.cmd", "npx.exe", "npx"]:
             try:
-                subprocess.run(
-                    [cmd, "--version"], check=True, capture_output=True, shell=True
-                )
+                subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True)
                 return cmd
             except subprocess.CalledProcessError:
                 continue

@@ -58,9 +56,7 @@ def _get_npx_command():
 def _parse_env_var(env_var: str) -> tuple[str, str]:
     """Parse environment variable string in format KEY=VALUE."""
     if "=" not in env_var:
-        logger.error(
-            f"Invalid environment variable format: {env_var}. Must be KEY=VALUE"
-        )
+        logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE")
         sys.exit(1)
     key, value = env_var.split("=", 1)
     return key.strip(), value.strip()

@@ -154,14 +150,10 @@ def _import_server(file: Path, server_object: str | None = None):
             True if it's supported.
         """
         if not isinstance(server_object, FastMCP):
-            logger.error(
-                f"The server object {object_name} is of type "
-                f"{type(server_object)} (expecting {FastMCP})."
-            )
+            logger.error(f"The server object {object_name} is of type " f"{type(server_object)} (expecting {FastMCP}).")
             if isinstance(server_object, LowLevelServer):
                 logger.warning(
-                    "Note that only FastMCP server is supported. Low level "
-                    "Server class is not yet supported."
+                    "Note that only FastMCP server is supported. Low level " "Server class is not yet supported."
                 )
             return False
         return True

@@ -172,10 +164,7 @@ def _import_server(file: Path, server_object: str | None = None):
         for name in ["mcp", "server", "app"]:
             if hasattr(module, name):
                 if not _check_server_object(getattr(module, name), f"{file}:{name}"):
-                    logger.error(
-                        f"Ignoring object '{file}:{name}' as it's not a valid "
-                        "server object"
-                    )
+                    logger.error(f"Ignoring object '{file}:{name}' as it's not a valid " "server object")
                     continue
                 return getattr(module, name)
 

@@ -280,8 +269,7 @@ def dev(
     npx_cmd = _get_npx_command()
     if not npx_cmd:
         logger.error(
-            "npx not found. Please ensure Node.js and npm are properly installed "
-            "and added to your system PATH."
+            "npx not found. Please ensure Node.js and npm are properly installed " "and added to your system PATH."
         )
         sys.exit(1)
 

@@ -383,8 +371,7 @@ def install(
         typer.Option(
             "--name",
             "-n",
-            help="Custom name for the server (defaults to server's name attribute or"
-            " file name)",
+            help="Custom name for the server (defaults to server's name attribute or" " file name)",
         ),
     ] = None,
     with_editable: Annotated[

@@ -458,8 +445,7 @@ def install(
             name = server.name
         except (ImportError, ModuleNotFoundError) as e:
             logger.debug(
-                "Could not import server (likely missing dependencies), using file"
-                " name",
+                "Could not import server (likely missing dependencies), using file" " name",
                 extra={"error": str(e)},
             )
             name = file.stem

@@ -477,11 +463,7 @@ def install(
     if env_file:
         if dotenv:
             try:
-                env_dict |= {
-                    k: v
-                    for k, v in dotenv.dotenv_values(env_file).items()
-                    if v is not None
-                }
+                env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None}
             except Exception as e:
                 logger.error(f"Failed to load .env file: {e}")
                 sys.exit(1)
@@ -24,9 +24,7 @@ logger = logging.getLogger("client")
 
 
 async def message_handler(
-    message: RequestResponder[types.ServerRequest, types.ClientResult]
-    | types.ServerNotification
-    | Exception,
+    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
     if isinstance(message, Exception):
         logger.error("Error: %s", message)

@@ -60,9 +58,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]])
             await run_session(*streams)
     else:
         # Use stdio client for commands
-        server_parameters = StdioServerParameters(
-            command=command_or_url, args=args, env=env_dict
-        )
+        server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
         async with stdio_client(server_parameters) as streams:
             await run_session(*streams)
 
@@ -17,12 +17,7 @@ from urllib.parse import urlencode, urljoin
 import anyio
 import httpx
 
-from mcp.shared.auth import (
-    OAuthClientInformationFull,
-    OAuthClientMetadata,
-    OAuthMetadata,
-    OAuthToken,
-)
+from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
 from mcp.types import LATEST_PROTOCOL_VERSION
 
 logger = logging.getLogger(__name__)

@@ -100,10 +95,7 @@ class OAuthClientProvider(httpx.Auth):
 
     def _generate_code_verifier(self) -> str:
         """Generate a cryptographically random code verifier for PKCE."""
-        return "".join(
-            secrets.choice(string.ascii_letters + string.digits + "-._~")
-            for _ in range(128)
-        )
+        return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
 
     def _generate_code_challenge(self, code_verifier: str) -> str:
         """Generate a code challenge from a code verifier using SHA256."""

@@ -148,9 +140,7 @@ class OAuthClientProvider(httpx.Auth):
                         return None
                     response.raise_for_status()
                     metadata_json = response.json()
-                    logger.debug(
-                        f"OAuth metadata discovered (no MCP header): {metadata_json}"
-                    )
+                    logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}")
                     return OAuthMetadata.model_validate(metadata_json)
                 except Exception:
                     logger.exception("Failed to discover OAuth metadata")

@@ -176,17 +166,11 @@ class OAuthClientProvider(httpx.Auth):
         registration_url = urljoin(auth_base_url, "/register")
 
         # Handle default scope
-        if (
-            client_metadata.scope is None
-            and metadata
-            and metadata.scopes_supported is not None
-        ):
+        if client_metadata.scope is None and metadata and metadata.scopes_supported is not None:
             client_metadata.scope = " ".join(metadata.scopes_supported)
 
         # Serialize client metadata
-        registration_data = client_metadata.model_dump(
-            by_alias=True, mode="json", exclude_none=True
-        )
+        registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
 
         async with httpx.AsyncClient() as client:
             try:

@@ -213,9 +197,7 @@ class OAuthClientProvider(httpx.Auth):
                 logger.exception("Registration error")
                 raise
 
-    async def async_auth_flow(
-        self, request: httpx.Request
-    ) -> AsyncGenerator[httpx.Request, httpx.Response]:
+    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
         """
         HTTPX auth flow integration.
         """

@@ -225,9 +207,7 @@ class OAuthClientProvider(httpx.Auth):
         await self.ensure_token()
         # Add Bearer token if available
         if self._current_tokens and self._current_tokens.access_token:
-            request.headers["Authorization"] = (
-                f"Bearer {self._current_tokens.access_token}"
-            )
+            request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}"
 
         response = yield request
 

@@ -305,11 +285,7 @@ class OAuthClientProvider(httpx.Auth):
             return
 
         # Try refreshing existing token
-        if (
-            self._current_tokens
-            and self._current_tokens.refresh_token
-            and await self._refresh_access_token()
-        ):
+        if self._current_tokens and self._current_tokens.refresh_token and await self._refresh_access_token():
             return
 
         # Fall back to full OAuth flow

@@ -361,12 +337,8 @@ class OAuthClientProvider(httpx.Auth):
         auth_code, returned_state = await self.callback_handler()
 
         # Validate state parameter for CSRF protection
-        if returned_state is None or not secrets.compare_digest(
-            returned_state, self._auth_state
-        ):
-            raise Exception(
-                f"State parameter mismatch: {returned_state} != {self._auth_state}"
-            )
+        if returned_state is None or not secrets.compare_digest(returned_state, self._auth_state):
+            raise Exception(f"State parameter mismatch: {returned_state} != {self._auth_state}")
 
         # Clear state after validation
         self._auth_state = None

@@ -377,9 +349,7 @@ class OAuthClientProvider(httpx.Auth):
         # Exchange authorization code for tokens
         await self._exchange_code_for_token(auth_code, client_info)
 
-    async def _exchange_code_for_token(
-        self, auth_code: str, client_info: OAuthClientInformationFull
-    ) -> None:
+    async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClientInformationFull) -> None:
         """Exchange authorization code for access token."""
         # Get token endpoint
         if self._metadata and self._metadata.token_endpoint:

@@ -412,17 +382,10 @@ class OAuthClientProvider(httpx.Auth):
             # Parse OAuth error response
             try:
                 error_data = response.json()
-                error_msg = error_data.get(
-                    "error_description", error_data.get("error", "Unknown error")
-                )
-                raise Exception(
-                    f"Token exchange failed: {error_msg} "
-                    f"(HTTP {response.status_code})"
-                )
+                error_msg = error_data.get("error_description", error_data.get("error", "Unknown error"))
+                raise Exception(f"Token exchange failed: {error_msg} " f"(HTTP {response.status_code})")
             except Exception:
-                raise Exception(
-                    f"Token exchange failed: {response.status_code} {response.text}"
-                )
+                raise Exception(f"Token exchange failed: {response.status_code} {response.text}")
 
         # Parse token response
         token_response = OAuthToken.model_validate(response.json())
@@ -38,16 +38,12 @@ class LoggingFnT(Protocol):
 class MessageHandlerFnT(Protocol):
     async def __call__(
         self,
-        message: RequestResponder[types.ServerRequest, types.ClientResult]
-        | types.ServerNotification
-        | Exception,
+        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
     ) -> None: ...
 
 
 async def _default_message_handler(
-    message: RequestResponder[types.ServerRequest, types.ClientResult]
-    | types.ServerNotification
-    | Exception,
+    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
     await anyio.lowlevel.checkpoint()
 

@@ -77,9 +73,7 @@ async def _default_logging_callback(
     pass
 
 
-ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
-    types.ClientResult | types.ErrorData
-)
+ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
 
 
 class ClientSession(

@@ -116,11 +110,7 @@ class ClientSession(
         self._message_handler = message_handler or _default_message_handler
 
     async def initialize(self) -> types.InitializeResult:
-        sampling = (
-            types.SamplingCapability()
-            if self._sampling_callback is not _default_sampling_callback
-            else None
-        )
+        sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
         roots = (
             # TODO: Should this be based on whether we
             # _will_ send notifications, or only whether

@@ -149,15 +139,10 @@ class ClientSession(
         )
 
         if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
-            raise RuntimeError(
-                "Unsupported protocol version from the server: "
-                f"{result.protocolVersion}"
-            )
+            raise RuntimeError("Unsupported protocol version from the server: " f"{result.protocolVersion}")
 
         await self.send_notification(
-            types.ClientNotification(
-                types.InitializedNotification(method="notifications/initialized")
-            )
+            types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
         )
 
         return result

@@ -207,33 +192,25 @@ class ClientSession(
             types.EmptyResult,
         )
 
-    async def list_resources(
-        self, cursor: str | None = None
-    ) -> types.ListResourcesResult:
+    async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult:
         """Send a resources/list request."""
         return await self.send_request(
             types.ClientRequest(
                 types.ListResourcesRequest(
                     method="resources/list",
-                    params=types.PaginatedRequestParams(cursor=cursor)
-                    if cursor is not None
-                    else None,
+                    params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
                 )
             ),
             types.ListResourcesResult,
         )
 
-    async def list_resource_templates(
-        self, cursor: str | None = None
-    ) -> types.ListResourceTemplatesResult:
+    async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult:
         """Send a resources/templates/list request."""
         return await self.send_request(
             types.ClientRequest(
                 types.ListResourceTemplatesRequest(
                     method="resources/templates/list",
-                    params=types.PaginatedRequestParams(cursor=cursor)
-                    if cursor is not None
-                    else None,
+                    params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
                 )
             ),
             types.ListResourceTemplatesResult,

@@ -305,17 +282,13 @@ class ClientSession(
             types.ClientRequest(
                 types.ListPromptsRequest(
                     method="prompts/list",
-                    params=types.PaginatedRequestParams(cursor=cursor)
-                    if cursor is not None
-                    else None,
+                    params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
                 )
             ),
             types.ListPromptsResult,
         )
 
-    async def get_prompt(
-        self, name: str, arguments: dict[str, str] | None = None
-    ) -> types.GetPromptResult:
+    async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
         """Send a prompts/get request."""
         return await self.send_request(
             types.ClientRequest(

@@ -352,9 +325,7 @@ class ClientSession(
             types.ClientRequest(
                 types.ListToolsRequest(
                     method="tools/list",
-                    params=types.PaginatedRequestParams(cursor=cursor)
-                    if cursor is not None
-                    else None,
+                    params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None,
                 )
             ),
             types.ListToolsResult,

@@ -370,9 +341,7 @@ class ClientSession(
             )
         )
 
-    async def _received_request(
-        self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
-    ) -> None:
+    async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
         ctx = RequestContext[ClientSession, Any](
             request_id=responder.request_id,
             meta=responder.request_meta,

@@ -395,22 +364,16 @@ class ClientSession(
 
             case types.PingRequest():
                 with responder:
-                    return await responder.respond(
-                        types.ClientResult(root=types.EmptyResult())
-                    )
+                    return await responder.respond(types.ClientResult(root=types.EmptyResult()))
 
     async def _handle_incoming(
         self,
-        req: RequestResponder[types.ServerRequest, types.ClientResult]
-        | types.ServerNotification
-        | Exception,
+        req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
     ) -> None:
         """Handle incoming messages by forwarding to the message handler."""
         await self._message_handler(req)
 
-    async def _received_notification(
-        self, notification: types.ServerNotification
-    ) -> None:
+    async def _received_notification(self, notification: types.ServerNotification) -> None:
         """Handle notifications from the server."""
         # Process specific notification types
         match notification.root:
@@ -62,9 +62,7 @@ class StreamableHttpParameters(BaseModel):
     terminate_on_close: bool = True
 
 
-ServerParameters: TypeAlias = (
-    StdioServerParameters | SseServerParameters | StreamableHttpParameters
-)
+ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
 
 
 class ClientSessionGroup:

@@ -261,9 +259,7 @@ class ClientSessionGroup:
                 )
                 read, write, _ = await session_stack.enter_async_context(client)
 
-            session = await session_stack.enter_async_context(
-                mcp.ClientSession(read, write)
-            )
+            session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
             result = await session.initialize()
 
             # Session successfully initialized.

@@ -280,9 +276,7 @@ class ClientSessionGroup:
             await session_stack.aclose()
             raise
 
-    async def _aggregate_components(
-        self, server_info: types.Implementation, session: mcp.ClientSession
-    ) -> None:
+    async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
         """Aggregates prompts, resources, and tools from a given session."""
 
         # Create a reverse index so we can find all prompts, resources, and
@@ -73,20 +73,16 @@ async def sse_client(
                                 match sse.event:
                                     case "endpoint":
                                         endpoint_url = urljoin(url, sse.data)
-                                        logger.debug(
-                                            f"Received endpoint URL: {endpoint_url}"
-                                        )
+                                        logger.debug(f"Received endpoint URL: {endpoint_url}")
 
                                         url_parsed = urlparse(url)
                                         endpoint_parsed = urlparse(endpoint_url)
                                         if (
                                             url_parsed.netloc != endpoint_parsed.netloc
-                                            or url_parsed.scheme
-                                            != endpoint_parsed.scheme
+                                            or url_parsed.scheme != endpoint_parsed.scheme
                                         ):
                                             error_msg = (
-                                                "Endpoint origin does not match "
-                                                f"connection origin: {endpoint_url}"
+                                                "Endpoint origin does not match " f"connection origin: {endpoint_url}"
                                             )
                                             logger.error(error_msg)
                                             raise ValueError(error_msg)

@@ -98,22 +94,16 @@ async def sse_client(
                                             message = types.JSONRPCMessage.model_validate_json(  # noqa: E501
                                                 sse.data
                                             )
-                                            logger.debug(
-                                                f"Received server message: {message}"
-                                            )
+                                            logger.debug(f"Received server message: {message}")
                                         except Exception as exc:
-                                            logger.error(
-                                                f"Error parsing server message: {exc}"
-                                            )
+                                            logger.error(f"Error parsing server message: {exc}")
                                             await read_stream_writer.send(exc)
                                             continue
 
                                         session_message = SessionMessage(message)
                                         await read_stream_writer.send(session_message)
                                     case _:
-                                        logger.warning(
-                                            f"Unknown SSE event: {sse.event}"
-                                        )
+                                        logger.warning(f"Unknown SSE event: {sse.event}")
                         except Exception as exc:
                             logger.error(f"Error in sse_reader: {exc}")
                             await read_stream_writer.send(exc)

@@ -124,9 +114,7 @@ async def sse_client(
                         try:
                             async with write_stream_reader:
                                 async for session_message in write_stream_reader:
-                                    logger.debug(
-                                        f"Sending client message: {session_message}"
-                                    )
+                                    logger.debug(f"Sending client message: {session_message}")
                                     response = await client.post(
                                         endpoint_url,
                                         json=session_message.message.model_dump(

@@ -136,19 +124,14 @@ async def sse_client(
                                         ),
                                     )
                                     response.raise_for_status()
-                                    logger.debug(
-                                        "Client message sent successfully: "
-                                        f"{response.status_code}"
-                                    )
+                                    logger.debug("Client message sent successfully: " f"{response.status_code}")
                         except Exception as exc:
                             logger.error(f"Error in post_writer: {exc}")
                         finally:
                             await write_stream.aclose()
 
                     endpoint_url = await tg.start(sse_reader)
-                    logger.debug(
-                        f"Starting post writer with endpoint URL: {endpoint_url}"
-                    )
+                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                     tg.start_soon(post_writer, endpoint_url)
 
                     try:
@@ -115,11 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
         process = await _create_platform_compatible_process(
             command=command,
             args=server.args,
-            env=(
-                {**get_default_environment(), **server.env}
-                if server.env is not None
-                else get_default_environment()
-            ),
+            env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
             errlog=errlog,
             cwd=server.cwd,
         )

@@ -163,9 +159,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
         try:
             async with write_stream_reader:
                 async for session_message in write_stream_reader:
-                    json = session_message.message.model_dump_json(
-                        by_alias=True, exclude_none=True
-                    )
+                    json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
                     await process.stdin.send(
                         (json + "\n").encode(
                             encoding=server.encoding,

@@ -229,8 +223,6 @@ async def _create_platform_compatible_process(
     if sys.platform == "win32":
         process = await create_windows_process(command, args, env, errlog, cwd)
     else:
-        process = await anyio.open_process(
-            [command, *args], env=env, stderr=errlog, cwd=cwd
-        )
+        process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd)
 
     return process
@@ -82,9 +82,7 @@ async def create_windows_process(
         return process
     except Exception:
         # Don't raise, let's try to create the process without creation flags
-        process = await anyio.open_process(
-            [command, *args], env=env, stderr=errlog, cwd=cwd
-        )
+        process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd)
         return process
 
 
@@ -106,9 +106,7 @@ class StreamableHTTPTransport:
             **self.headers,
         }
 
-    def _update_headers_with_session(
-        self, base_headers: dict[str, str]
-    ) -> dict[str, str]:
+    def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
         """Update headers with session ID if available."""
         headers = base_headers.copy()
         if self.session_id:

@@ -117,17 +115,11 @@ class StreamableHTTPTransport:
 
     def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
         """Check if the message is an initialization request."""
-        return (
-            isinstance(message.root, JSONRPCRequest)
-            and message.root.method == "initialize"
-        )
+        return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
 
     def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
         """Check if the message is an initialized notification."""
-        return (
-            isinstance(message.root, JSONRPCNotification)
-            and message.root.method == "notifications/initialized"
-        )
+        return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
 
     def _maybe_extract_session_id_from_response(
         self,

@@ -153,9 +145,7 @@ class StreamableHTTPTransport:
                 logger.debug(f"SSE message: {message}")
 
                 # If this is a response and we have original_request_id, replace it
-                if original_request_id is not None and isinstance(
-                    message.root, JSONRPCResponse | JSONRPCError
-                ):
+                if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
                     message.root.id = original_request_id
 
                 session_message = SessionMessage(message)

@@ -227,7 +217,8 @@ class StreamableHTTPTransport:
             self.url,
             headers=headers,
             timeout=httpx.Timeout(
-                self.timeout.total_seconds(), read=ctx.sse_read_timeout.total_seconds()
+                self.timeout.total_seconds(),
+                read=ctx.sse_read_timeout.total_seconds(),
             ),
         ) as event_source:
             event_source.response.raise_for_status()

@@ -298,9 +289,7 @@ class StreamableHTTPTransport:
             logger.error(f"Error parsing JSON response: {exc}")
             await read_stream_writer.send(exc)
 
-    async def _handle_sse_response(
-        self, response: httpx.Response, ctx: RequestContext
-    ) -> None:
+    async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
         """Handle SSE response from the server."""
         try:
             event_source = EventSource(response)

@@ -308,11 +297,7 @@ class StreamableHTTPTransport:
                 is_complete = await self._handle_sse_event(
                     sse,
                     ctx.read_stream_writer,
-                    resumption_callback=(
-                        ctx.metadata.on_resumption_token_update
-                        if ctx.metadata
-                        else None
-                    ),
+                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
                 )
                 # If the SSE event indicates completion, like returning respose/error
                 # break the loop

@@ -455,12 +440,8 @@ async def streamablehttp_client(
     """
     transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
 
-    read_stream_writer, read_stream = anyio.create_memory_object_stream[
-        SessionMessage | Exception
-    ](0)
-    write_stream, write_stream_reader = anyio.create_memory_object_stream[
-        SessionMessage
-    ](0)
+    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
+    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
 
     async with anyio.create_task_group() as tg:
         try:

@@ -476,9 +457,7 @@ async def streamablehttp_client(
             ) as client:
                 # Define callbacks that need access to tg
                 def start_get_stream() -> None:
-                    tg.start_soon(
-                        transport.handle_get_stream, client, read_stream_writer
-                    )
+                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
 
                 tg.start_soon(
                     transport.post_writer,
@@ -19,10 +19,7 @@ logger = logging.getLogger(__name__)
 async def websocket_client(
     url: str,
 ) -> AsyncGenerator[
-    tuple[
-        MemoryObjectReceiveStream[SessionMessage | Exception],
-        MemoryObjectSendStream[SessionMessage],
-    ],
+    tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
     None,
 ]:
     """

@@ -74,9 +71,7 @@ async def websocket_client(
         async with write_stream_reader:
             async for session_message in write_stream_reader:
                 # Convert to a dict, then to JSON
-                msg_dict = session_message.message.model_dump(
-                    by_alias=True, mode="json", exclude_none=True
-                )
+                msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
                 await ws.send(json.dumps(msg_dict))
 
     async with anyio.create_task_group() as tg:
@@ -2,7 +2,4 @@ from pydantic import ValidationError
 
 
 def stringify_pydantic_error(validation_error: ValidationError) -> str:
-    return "\n".join(
-        f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
-        for e in validation_error.errors()
-    )
+    return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
@@ -7,9 +7,7 @@ from starlette.datastructures import FormData, QueryParams
|
|||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import RedirectResponse, Response
|
from starlette.responses import RedirectResponse, Response
|
||||||
|
|
||||||
from mcp.server.auth.errors import (
|
from mcp.server.auth.errors import stringify_pydantic_error
|
||||||
stringify_pydantic_error,
|
|
||||||
)
|
|
||||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||||
from mcp.server.auth.provider import (
|
from mcp.server.auth.provider import (
|
||||||
AuthorizationErrorCode,
|
AuthorizationErrorCode,
|
||||||
@@ -18,10 +16,7 @@ from mcp.server.auth.provider import (
|
|||||||
OAuthAuthorizationServerProvider,
|
OAuthAuthorizationServerProvider,
|
||||||
construct_redirect_uri,
|
construct_redirect_uri,
|
||||||
)
|
)
|
||||||
from mcp.shared.auth import (
|
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||||
InvalidRedirectUriError,
|
|
||||||
InvalidScopeError,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,23 +24,16 @@ logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
client_id: str = Field(..., description="The client ID")
- redirect_uri: AnyUrl | None = Field(
- None, description="URL to redirect to after authorization"
- )
+ redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")

# see OAuthClientMetadata; we only support `code`
- response_type: Literal["code"] = Field(
- ..., description="Must be 'code' for authorization code flow"
- )
+ response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
code_challenge: str = Field(..., description="PKCE code challenge")
- code_challenge_method: Literal["S256"] = Field(
- "S256", description="PKCE code challenge method, must be S256"
- )
+ code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
state: str | None = Field(None, description="Optional state parameter")
scope: str | None = Field(
None,
- description="Optional scope; if specified, should be "
- "a space-separated list of scope strings",
+ description="Optional scope; if specified, should be " "a space-separated list of scope strings",
)

@@ -57,9 +45,7 @@ class AuthorizationErrorResponse(BaseModel):
state: str | None = None


- def best_effort_extract_string(
- key: str, params: None | FormData | QueryParams
- ) -> str | None:
+ def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
if params is None:
return None
value = params.get(key)
@@ -138,9 +124,7 @@ class AuthorizationHandler:

if redirect_uri and client:
return RedirectResponse(
- url=construct_redirect_uri(
- str(redirect_uri), **error_resp.model_dump(exclude_none=True)
- ),
+ url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
status_code=302,
headers={"Cache-Control": "no-store"},
)
@@ -172,9 +156,7 @@ class AuthorizationHandler:
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
error = "unsupported_response_type"
break
- return await error_response(
- error, stringify_pydantic_error(validation_error)
- )
+ return await error_response(error, stringify_pydantic_error(validation_error))

# Get client information
client = await self.provider.get_client(
@@ -229,16 +211,9 @@ class AuthorizationHandler:
)
except AuthorizeError as e:
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
- return await error_response(
- error=e.error,
- error_description=e.error_description,
- )
+ return await error_response(error=e.error, error_description=e.error_description)

except Exception as validation_error:
# Catch-all for unexpected errors
- logger.exception(
- "Unexpected error in authorization_handler", exc_info=validation_error
- )
- return await error_response(
- error="server_error", error_description="An unexpected error occurred"
- )
+ logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
+ return await error_response(error="server_error", error_description="An unexpected error occurred")
@@ -10,11 +10,7 @@ from starlette.responses import Response

from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
- from mcp.server.auth.provider import (
- OAuthAuthorizationServerProvider,
- RegistrationError,
- RegistrationErrorCode,
- )
+ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
from mcp.server.auth.settings import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata

@@ -60,9 +56,7 @@ class RegistrationHandler:

if client_metadata.scope is None and self.options.default_scopes is not None:
client_metadata.scope = " ".join(self.options.default_scopes)
- elif (
- client_metadata.scope is not None and self.options.valid_scopes is not None
- ):
+ elif client_metadata.scope is not None and self.options.valid_scopes is not None:
requested_scopes = set(client_metadata.scope.split())
valid_scopes = set(self.options.valid_scopes)
if not requested_scopes.issubset(valid_scopes):
@@ -78,8 +72,7 @@ class RegistrationHandler:
return PydanticJSONResponse(
content=RegistrationErrorResponse(
error="invalid_client_metadata",
- error_description="grant_types must be authorization_code "
- "and refresh_token",
+ error_description="grant_types must be authorization_code " "and refresh_token",
),
status_code=400,
)
@@ -122,8 +115,6 @@ class RegistrationHandler:
except RegistrationError as e:
# Handle registration errors as defined in RFC 7591 Section 3.2.2
return PydanticJSONResponse(
- content=RegistrationErrorResponse(
- error=e.error, error_description=e.error_description
- ),
+ content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
status_code=400,
)
@@ -10,15 +10,8 @@ from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
- from mcp.server.auth.middleware.client_auth import (
- AuthenticationError,
- ClientAuthenticator,
- )
- from mcp.server.auth.provider import (
- AccessToken,
- OAuthAuthorizationServerProvider,
- RefreshToken,
- )
+ from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
+ from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken


class RevocationRequest(BaseModel):
@@ -7,19 +7,10 @@ from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request

- from mcp.server.auth.errors import (
- stringify_pydantic_error,
- )
+ from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
- from mcp.server.auth.middleware.client_auth import (
- AuthenticationError,
- ClientAuthenticator,
- )
- from mcp.server.auth.provider import (
- OAuthAuthorizationServerProvider,
- TokenError,
- TokenErrorCode,
- )
+ from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
+ from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
from mcp.shared.auth import OAuthToken

@@ -27,9 +18,7 @@ class AuthorizationCodeRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
grant_type: Literal["authorization_code"]
code: str = Field(..., description="The authorization code")
- redirect_uri: AnyUrl | None = Field(
- None, description="Must be the same as redirect URI provided in /authorize"
- )
+ redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
client_id: str
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
client_secret: str | None = None
@@ -127,8 +116,7 @@ class TokenHandler:
TokenErrorResponse(
error="unsupported_grant_type",
error_description=(
- f"Unsupported grant type (supported grant types are "
- f"{client_info.grant_types})"
+ f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})"
),
)
)
@@ -137,9 +125,7 @@ class TokenHandler:

match token_request:
case AuthorizationCodeRequest():
- auth_code = await self.provider.load_authorization_code(
- client_info, token_request.code
- )
+ auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
if auth_code is None or auth_code.client_id != token_request.client_id:
# if code belongs to different client, pretend it doesn't exist
return self.response(
@@ -169,18 +155,13 @@ class TokenHandler:
return self.response(
TokenErrorResponse(
error="invalid_request",
- error_description=(
- "redirect_uri did not match the one "
- "used when creating auth code"
- ),
+ error_description=("redirect_uri did not match the one " "used when creating auth code"),
)
)

# Verify PKCE code verifier
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
- hashed_code_verifier = (
- base64.urlsafe_b64encode(sha256).decode().rstrip("=")
- )
+ hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")

if hashed_code_verifier != auth_code.code_challenge:
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
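As background for the PKCE check collapsed above, a minimal standalone sketch of the S256 relation (RFC 7636) between a client's code_verifier and the code_challenge stored with the authorization code; the names below are illustrative and not taken from the diff:

import base64
import hashlib
import secrets


def s256_challenge(verifier: str) -> str:
    # base64url-encoded SHA-256 of the verifier, with padding stripped
    digest = hashlib.sha256(verifier.encode()).digest()
    return base64.urlsafe_b64encode(digest).decode().rstrip("=")


code_verifier = secrets.token_urlsafe(48)       # kept by the client until the /token call
code_challenge = s256_challenge(code_verifier)  # sent earlier with the /authorize request

# The token handler recomputes the hash from the submitted verifier and compares it
# against the stored challenge, exactly as in the collapsed lines above.
assert s256_challenge(code_verifier) == code_challenge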
@@ -193,9 +174,7 @@ class TokenHandler:

try:
# Exchange authorization code for tokens
- tokens = await self.provider.exchange_authorization_code(
- client_info, auth_code
- )
+ tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
except TokenError as e:
return self.response(
TokenErrorResponse(
@@ -205,13 +184,8 @@ class TokenHandler:
)

case RefreshTokenRequest():
- refresh_token = await self.provider.load_refresh_token(
- client_info, token_request.refresh_token
- )
- if (
- refresh_token is None
- or refresh_token.client_id != token_request.client_id
- ):
+ refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
+ if refresh_token is None or refresh_token.client_id != token_request.client_id:
# if token belongs to different client, pretend it doesn't exist
return self.response(
TokenErrorResponse(
@@ -230,29 +204,20 @@ class TokenHandler:
)

# Parse scopes if provided
- scopes = (
- token_request.scope.split(" ")
- if token_request.scope
- else refresh_token.scopes
- )
+ scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

for scope in scopes:
if scope not in refresh_token.scopes:
return self.response(
TokenErrorResponse(
error="invalid_scope",
- error_description=(
- f"cannot request scope `{scope}` "
- "not provided by refresh token"
- ),
+ error_description=(f"cannot request scope `{scope}` " "not provided by refresh token"),
)
)

try:
# Exchange refresh token for new tokens
- tokens = await self.provider.exchange_refresh_token(
- client_info, refresh_token, scopes
- )
+ tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
except TokenError as e:
return self.response(
TokenErrorResponse(
@@ -7,9 +7,7 @@ from mcp.server.auth.provider import AccessToken

# Create a contextvar to store the authenticated user
# The default is None, indicating no authenticated user is present
- auth_context_var = contextvars.ContextVar[AuthenticatedUser | None](
- "auth_context", default=None
- )
+ auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)


def get_access_token() -> AccessToken | None:
@@ -1,11 +1,7 @@
import time
from typing import Any

- from starlette.authentication import (
- AuthCredentials,
- AuthenticationBackend,
- SimpleUser,
- )
+ from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send
@@ -35,11 +31,7 @@ class BearerAuthBackend(AuthenticationBackend):

async def authenticate(self, conn: HTTPConnection):
auth_header = next(
- (
- conn.headers.get(key)
- for key in conn.headers
- if key.lower() == "authorization"
- ),
+ (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
None,
)
if not auth_header or not auth_header.lower().startswith("bearer "):
@@ -87,10 +79,7 @@ class RequireAuthMiddleware:

for required_scope in self.required_scopes:
# auth_credentials should always be provided; this is just paranoia
- if (
- auth_credentials is None
- or required_scope not in auth_credentials.scopes
- ):
+ if auth_credentials is None or required_scope not in auth_credentials.scopes:
raise HTTPException(status_code=403, detail="Insufficient scope")

await self.app(scope, receive, send)
@@ -30,9 +30,7 @@ class ClientAuthenticator:
"""
self.provider = provider

- async def authenticate(
- self, client_id: str, client_secret: str | None
- ) -> OAuthClientInformationFull:
+ async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
# Look up client information
client = await self.provider.get_client(client_id)
if not client:
@@ -47,10 +45,7 @@ class ClientAuthenticator:
if client.client_secret != client_secret:
raise AuthenticationError("Invalid client_secret")

- if (
- client.client_secret_expires_at
- and client.client_secret_expires_at < int(time.time())
- ):
+ if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
raise AuthenticationError("Client secret has expired")

return client
@@ -4,10 +4,7 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import AnyUrl, BaseModel

- from mcp.shared.auth import (
- OAuthClientInformationFull,
- OAuthToken,
- )
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken


class AuthorizationParams(BaseModel):
@@ -96,9 +93,7 @@ RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)


- class OAuthAuthorizationServerProvider(
- Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]
- ):
+ class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
"""
Retrieves client information by client ID.
@@ -129,9 +124,7 @@ class OAuthAuthorizationServerProvider(
"""
...

- async def authorize(
- self, client: OAuthClientInformationFull, params: AuthorizationParams
- ) -> str:
+ async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
"""
Called as part of the /authorize endpoint, and returns a URL that the client
will be redirected to.
@@ -207,9 +200,7 @@ class OAuthAuthorizationServerProvider(
"""
...

- async def load_refresh_token(
- self, client: OAuthClientInformationFull, refresh_token: str
- ) -> RefreshTokenT | None:
+ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
"""
Loads a RefreshToken by its token string.

@@ -31,11 +31,7 @@ def validate_issuer_url(url: AnyHttpUrl):
"""

# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
- if (
- url.scheme != "https"
- and url.host != "localhost"
- and not url.host.startswith("127.0.0.1")
- ):
+ if url.scheme != "https" and url.host != "localhost" and not url.host.startswith("127.0.0.1"):
raise ValueError("Issuer URL must be HTTPS")

# No fragments or query parameters allowed
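A quick illustration of the behaviour encoded by the collapsed condition above, assuming validate_issuer_url is imported from the auth routes module this hunk touches (the URLs are hypothetical):

from pydantic import AnyHttpUrl, TypeAdapter

adapter = TypeAdapter(AnyHttpUrl)

validate_issuer_url(adapter.validate_python("https://auth.example.com"))  # accepted: HTTPS
validate_issuer_url(adapter.validate_python("http://localhost:8000"))     # accepted: localhost exception
validate_issuer_url(adapter.validate_python("http://127.0.0.1:8000"))     # accepted: loopback exception
validate_issuer_url(adapter.validate_python("http://auth.example.com"))   # raises ValueError (not HTTPS)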
@@ -73,9 +69,7 @@ def create_auth_routes(
) -> list[Route]:
validate_issuer_url(issuer_url)

- client_registration_options = (
- client_registration_options or ClientRegistrationOptions()
- )
+ client_registration_options = client_registration_options or ClientRegistrationOptions()
revocation_options = revocation_options or RevocationOptions()
metadata = build_metadata(
issuer_url,
@@ -177,15 +171,11 @@ def build_metadata(

# Add registration endpoint if supported
if client_registration_options.enabled:
- metadata.registration_endpoint = AnyHttpUrl(
- str(issuer_url).rstrip("/") + REGISTRATION_PATH
- )
+ metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)

# Add revocation endpoint if supported
if revocation_options.enabled:
- metadata.revocation_endpoint = AnyHttpUrl(
- str(issuer_url).rstrip("/") + REVOCATION_PATH
- )
+ metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]

return metadata
@@ -15,8 +15,7 @@ class RevocationOptions(BaseModel):
class AuthSettings(BaseModel):
issuer_url: AnyHttpUrl = Field(
...,
- description="URL advertised as OAuth issuer; this should be the URL the server "
- "is reachable at",
+ description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at",
)
service_documentation_url: AnyHttpUrl | None = None
client_registration_options: ClientRegistrationOptions | None = None
@@ -42,13 +42,9 @@ class AssistantMessage(Message):
super().__init__(content=content, **kwargs)


- message_validator = TypeAdapter[UserMessage | AssistantMessage](
- UserMessage | AssistantMessage
- )
+ message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)

- SyncPromptResult = (
- str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
- )
+ SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]

@@ -56,24 +52,16 @@ class PromptArgument(BaseModel):
"""An argument that can be passed to a prompt."""

name: str = Field(description="Name of the argument")
- description: str | None = Field(
- None, description="Description of what the argument does"
- )
- required: bool = Field(
- default=False, description="Whether the argument is required"
- )
+ description: str | None = Field(None, description="Description of what the argument does")
+ required: bool = Field(default=False, description="Whether the argument is required")


class Prompt(BaseModel):
"""A prompt template that can be rendered with parameters."""

name: str = Field(description="Name of the prompt")
- description: str | None = Field(
- None, description="Description of what the prompt does"
- )
- arguments: list[PromptArgument] | None = Field(
- None, description="Arguments that can be passed to the prompt"
- )
+ description: str | None = Field(None, description="Description of what the prompt does")
+ arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)

@classmethod
@@ -154,14 +142,10 @@ class Prompt(BaseModel):
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
- content = pydantic_core.to_json(
- msg, fallback=str, indent=2
- ).decode()
+ content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
messages.append(Message(role="user", content=content))
except Exception:
- raise ValueError(
- f"Could not convert prompt result to message: {msg}"
- )
+ raise ValueError(f"Could not convert prompt result to message: {msg}")

return messages
except Exception as e:
@@ -39,9 +39,7 @@ class PromptManager:
self._prompts[prompt.name] = prompt
return prompt

- async def render_prompt(
- self, name: str, arguments: dict[str, Any] | None = None
- ) -> list[Message]:
+ async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
"""Render a prompt by name with arguments."""
prompt = self.get_prompt(name)
if not prompt:
@@ -19,13 +19,9 @@ class Resource(BaseModel, abc.ABC):

model_config = ConfigDict(validate_default=True)

- uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
- default=..., description="URI of the resource"
- )
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
name: str | None = Field(description="Name of the resource", default=None)
- description: str | None = Field(
- description="Description of the resource", default=None
- )
+ description: str | None = Field(description="Description of the resource", default=None)
mime_type: str = Field(
default="text/plain",
description="MIME type of the resource content",
@@ -15,18 +15,12 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource
class ResourceTemplate(BaseModel):
"""A template for dynamically creating resources."""

- uri_template: str = Field(
- description="URI template with parameters (e.g. weather://{city}/current)"
- )
+ uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
name: str = Field(description="Name of the resource")
description: str | None = Field(description="Description of what the resource does")
- mime_type: str = Field(
- default="text/plain", description="MIME type of the resource content"
- )
+ mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
fn: Callable[..., Any] = Field(exclude=True)
- parameters: dict[str, Any] = Field(
- description="JSON schema for function parameters"
- )
+ parameters: dict[str, Any] = Field(description="JSON schema for function parameters")

@classmethod
def from_function(
@@ -54,9 +54,7 @@ class FunctionResource(Resource):
async def read(self) -> str | bytes:
"""Read the resource by calling the wrapped function."""
try:
- result = (
- await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
- )
+ result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
if isinstance(result, Resource):
return await result.read()
elif isinstance(result, bytes):
@@ -141,9 +139,7 @@ class HttpResource(Resource):
"""A resource that reads from an HTTP endpoint."""

url: str = Field(description="URL to fetch content from")
- mime_type: str = Field(
- default="application/json", description="MIME type of the resource content"
- )
+ mime_type: str = Field(default="application/json", description="MIME type of the resource content")

async def read(self) -> str | bytes:
"""Read the HTTP content."""
@@ -157,15 +153,9 @@ class DirectoryResource(Resource):
"""A resource that lists files in a directory."""

path: Path = Field(description="Path to the directory")
- recursive: bool = Field(
- default=False, description="Whether to list files recursively"
- )
- pattern: str | None = Field(
- default=None, description="Optional glob pattern to filter files"
- )
- mime_type: str = Field(
- default="application/json", description="MIME type of the resource content"
- )
+ recursive: bool = Field(default=False, description="Whether to list files recursively")
+ pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
+ mime_type: str = Field(default="application/json", description="MIME type of the resource content")

@pydantic.field_validator("path")
@classmethod
@@ -184,16 +174,8 @@ class DirectoryResource(Resource):

try:
if self.pattern:
- return (
- list(self.path.glob(self.pattern))
- if not self.recursive
- else list(self.path.rglob(self.pattern))
- )
- return (
- list(self.path.glob("*"))
- if not self.recursive
- else list(self.path.rglob("*"))
- )
+ return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern))
+ return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*"))
except Exception as e:
raise ValueError(f"Error listing directory {self.path}: {e}")

@@ -97,9 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):

# StreamableHTTP settings
json_response: bool = False
- stateless_http: bool = (
- False # If True, uses true stateless mode (new transport per request)
- )
+ stateless_http: bool = False # If True, uses true stateless mode (new transport per request)

# resource settings
warn_on_duplicate_resources: bool = True
@@ -115,9 +113,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
description="List of dependencies to install in the server environment",
)

- lifespan: (
- Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
- ) = Field(None, description="Lifespan context manager")
+ lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
+ None, description="Lifespan context manager"
+ )

auth: AuthSettings | None = None

@@ -125,9 +123,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
- ) -> Callable[
- [MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
- ]:
+ ) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
@@ -141,8 +137,7 @@ class FastMCP:
self,
name: str | None = None,
instructions: str | None = None,
- auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
- | None = None,
+ auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None,
event_store: EventStore | None = None,
*,
tools: list[Tool] | None = None,
@@ -153,31 +148,18 @@ class FastMCP:
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
- lifespan=(
- lifespan_wrapper(self, self.settings.lifespan)
- if self.settings.lifespan
- else default_lifespan
- ),
- )
- self._tool_manager = ToolManager(
- tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
- )
- self._resource_manager = ResourceManager(
- warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
- )
- self._prompt_manager = PromptManager(
- warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
+ lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
)
+ self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
+ self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
+ self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
if (self.settings.auth is not None) != (auth_server_provider is not None):
# TODO: after we support separate authorization servers (see
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
# we should validate that if auth is enabled, we have either an
# auth_server_provider to host our own authorization server,
# OR the URL of a 3rd party authorization server.
- raise ValueError(
- "settings.auth must be specified if and only if auth_server_provider "
- "is specified"
- )
+ raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified")
self._auth_server_provider = auth_server_provider
self._event_store = event_store
self._custom_starlette_routes: list[Route] = []
@@ -340,9 +322,7 @@ class FastMCP:
description: Optional description of what the tool does
annotations: Optional ToolAnnotations providing additional tool information
"""
- self._tool_manager.add_tool(
- fn, name=name, description=description, annotations=annotations
- )
+ self._tool_manager.add_tool(fn, name=name, description=description, annotations=annotations)

def tool(
self,
@@ -379,14 +359,11 @@ class FastMCP:
# Check if user passed function directly instead of calling decorator
if callable(name):
raise TypeError(
- "The @tool decorator was used incorrectly. "
- "Did you forget to call it? Use @tool() instead of @tool"
+ "The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool"
)

def decorator(fn: AnyFunction) -> AnyFunction:
- self.add_tool(
- fn, name=name, description=description, annotations=annotations
- )
+ self.add_tool(fn, name=name, description=description, annotations=annotations)
return fn

return decorator
@@ -462,8 +439,7 @@ class FastMCP:

if uri_params != func_params:
raise ValueError(
- f"Mismatch between URI parameters {uri_params} "
- f"and function parameters {func_params}"
+ f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}"
)

# Register as template
@@ -496,9 +472,7 @@ class FastMCP:
"""
self._prompt_manager.add_prompt(prompt)

- def prompt(
- self, name: str | None = None, description: str | None = None
- ) -> Callable[[AnyFunction], AnyFunction]:
+ def prompt(self, name: str | None = None, description: str | None = None) -> Callable[[AnyFunction], AnyFunction]:
"""Decorator to register a prompt.

Args:
@@ -665,9 +639,7 @@ class FastMCP:
self.settings.mount_path = mount_path

# Create normalized endpoint considering the mount path
- normalized_message_endpoint = self._normalize_path(
- self.settings.mount_path, self.settings.message_path
- )
+ normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path)

# Set up auth context and dependencies

@@ -764,9 +736,7 @@ class FastMCP:
routes.extend(self._custom_starlette_routes)

# Create Starlette app with routes and middleware
- return Starlette(
- debug=self.settings.debug, routes=routes, middleware=middleware
- )
+ return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)

def streamable_http_app(self) -> Starlette:
"""Return an instance of the StreamableHTTP server app."""
@@ -783,9 +753,7 @@ class FastMCP:
)

# Create the ASGI handler
- async def handle_streamable_http(
- scope: Scope, receive: Receive, send: Send
- ) -> None:
+ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
await self.session_manager.handle_request(scope, receive, send)

# Create routes
@@ -861,9 +829,7 @@ class FastMCP:
for prompt in prompts
]

- async def get_prompt(
- self, name: str, arguments: dict[str, Any] | None = None
- ) -> GetPromptResult:
+ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
"""Get a prompt by name with arguments."""
try:
messages = await self._prompt_manager.render_prompt(name, arguments)
@@ -936,9 +902,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
def __init__(
self,
*,
- request_context: (
- RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
- ) = None,
+ request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
@@ -962,9 +926,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
raise ValueError("Context is not available outside of a request")
return self._request_context

- async def report_progress(
- self, progress: float, total: float | None = None, message: str | None = None
- ) -> None:
+ async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
"""Report progress for the current operation.

Args:
@@ -972,11 +934,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
total: Optional total value e.g. 100
message: Optional message e.g. Starting render...
"""
- progress_token = (
- self.request_context.meta.progressToken
- if self.request_context.meta
- else None
- )
+ progress_token = self.request_context.meta.progressToken if self.request_context.meta else None

if progress_token is None:
return
@@ -997,9 +955,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
Returns:
The resource content as either text or bytes
"""
- assert (
- self._fastmcp is not None
- ), "Context is not available outside of a request"
+ assert self._fastmcp is not None, "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri)

async def log(
@@ -1027,11 +983,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
@property
def client_id(self) -> str | None:
"""Get the client ID if available."""
- return (
- getattr(self.request_context.meta, "client_id", None)
- if self.request_context.meta
- else None
- )
+ return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None

@property
def request_id(self) -> str:
@@ -25,16 +25,11 @@ class Tool(BaseModel):
description: str = Field(description="Description of what the tool does")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field(
- description="Metadata about the function including a pydantic model for tool"
- " arguments"
+ description="Metadata about the function including a pydantic model for tool" " arguments"
)
is_async: bool = Field(description="Whether the tool is async")
- context_kwarg: str | None = Field(
- None, description="Name of the kwarg that should receive context"
- )
- annotations: ToolAnnotations | None = Field(
- None, description="Optional annotations for the tool"
- )
+ context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
+ annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")

@classmethod
def from_function(
@@ -93,9 +88,7 @@ class Tool(BaseModel):
self.fn,
self.is_async,
arguments,
- {self.context_kwarg: context}
- if self.context_kwarg is not None
- else None,
+ {self.context_kwarg: context} if self.context_kwarg is not None else None,
)
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e
@@ -50,9 +50,7 @@ class ToolManager:
annotations: ToolAnnotations | None = None,
) -> Tool:
"""Add a tool to the server."""
- tool = Tool.from_function(
- fn, name=name, description=description, annotations=annotations
- )
+ tool = Tool.from_function(fn, name=name, description=description, annotations=annotations)
existing = self._tools.get(tool.name)
if existing:
if self.warn_on_duplicate_tools:
@@ -102,9 +102,7 @@ class FuncMetadata(BaseModel):
)


- def func_metadata(
- func: Callable[..., Any], skip_names: Sequence[str] = ()
- ) -> FuncMetadata:
+ def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its
signature.

@@ -131,9 +129,7 @@ def func_metadata(
globalns = getattr(func, "__globals__", {})
for param in params.values():
if param.name.startswith("_"):
- raise InvalidSignature(
- f"Parameter {param.name} of {func.__name__} cannot start with '_'"
- )
+ raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
if param.name in skip_names:
continue
annotation = param.annotation
@@ -142,11 +138,7 @@ def func_metadata(
if annotation is None:
annotation = Annotated[
None,
- Field(
- default=param.default
- if param.default is not inspect.Parameter.empty
- else PydanticUndefined
- ),
+ Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
]

# Untyped field
@@ -160,9 +152,7 @@ def func_metadata(

field_info = FieldInfo.from_annotated_attribute(
_get_typed_annotation(annotation, globalns),
- param.default
- if param.default is not inspect.Parameter.empty
- else PydanticUndefined,
+ param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
)
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
continue
@@ -177,9 +167,7 @@ def func_metadata(


def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
- def try_eval_type(
- value: Any, globalns: dict[str, Any], localns: dict[str, Any]
- ) -> tuple[Any, bool]:
+ def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]:
try:
return eval_type_backport(value, globalns, localns), True
except NameError:
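For context on the func_metadata hunks above, a simplified sketch of the underlying technique: deriving a pydantic model from a function signature so call arguments can be validated and coerced before dispatch. This is an illustration only; the real func_metadata additionally handles PydanticUndefined defaults, forward references, skipped names, and untyped parameters:

import inspect
from typing import Any

from pydantic import create_model


def arguments_model(func) -> type:
    # Build a pydantic model whose fields mirror func's parameters (simplified).
    fields: dict[str, Any] = {}
    for param in inspect.signature(func).parameters.values():
        annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
        default = param.default if param.default is not inspect.Parameter.empty else ...
        fields[param.name] = (annotation, default)
    return create_model(f"{func.__name__}Arguments", **fields)


def add(a: int, b: int = 2) -> int:
    return a + b


ArgsModel = arguments_model(add)
print(ArgsModel(a="3"))  # a=3 b=2 (the string "3" is coerced to int by validation)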
@@ -95,9 +95,7 @@ LifespanResultT = TypeVar("LifespanResultT")
RequestT = TypeVar("RequestT", default=Any)

# This will be properly typed in each Server instance's context
- request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
- contextvars.ContextVar("request_ctx")
- )
+ request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")


class NotificationOptions:
@@ -140,9 +138,7 @@ class Server(Generic[LifespanResultT, RequestT]):
self.version = version
self.instructions = instructions
self.lifespan = lifespan
- self.request_handlers: dict[
- type, Callable[..., Awaitable[types.ServerResult]]
- ] = {
+ self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
types.PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
@@ -189,9 +185,7 @@ class Server(Generic[LifespanResultT, RequestT]):

# Set prompt capabilities if handler exists
if types.ListPromptsRequest in self.request_handlers:
- prompts_capability = types.PromptsCapability(
- listChanged=notification_options.prompts_changed
- )
+ prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed)

# Set resource capabilities if handler exists
if types.ListResourcesRequest in self.request_handlers:
@@ -201,9 +195,7 @@ class Server(Generic[LifespanResultT, RequestT]):

# Set tool capabilities if handler exists
if types.ListToolsRequest in self.request_handlers:
- tools_capability = types.ToolsCapability(
- listChanged=notification_options.tools_changed
- )
+ tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed)

# Set logging capabilities if handler exists
if types.SetLevelRequest in self.request_handlers:
@@ -239,9 +231,7 @@ class Server(Generic[LifespanResultT, RequestT]):

     def get_prompt(self):
         def decorator(
-            func: Callable[
-                [str, dict[str, str] | None], Awaitable[types.GetPromptResult]
-            ],
+            func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
         ):
             logger.debug("Registering handler for GetPromptRequest")

@@ -260,9 +250,7 @@ class Server(Generic[LifespanResultT, RequestT]):

             async def handler(_: Any):
                 resources = await func()
-                return types.ServerResult(
-                    types.ListResourcesResult(resources=resources)
-                )
+                return types.ServerResult(types.ListResourcesResult(resources=resources))

             self.request_handlers[types.ListResourcesRequest] = handler
             return func
@@ -275,9 +263,7 @@ class Server(Generic[LifespanResultT, RequestT]):

             async def handler(_: Any):
                 templates = await func()
-                return types.ServerResult(
-                    types.ListResourceTemplatesResult(resourceTemplates=templates)
-                )
+                return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))

             self.request_handlers[types.ListResourceTemplatesRequest] = handler
             return func
@@ -286,9 +272,7 @@ class Server(Generic[LifespanResultT, RequestT]):

     def read_resource(self):
         def decorator(
-            func: Callable[
-                [AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
-            ],
+            func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]],
         ):
             logger.debug("Registering handler for ReadResourceRequest")

@@ -323,8 +307,7 @@ class Server(Generic[LifespanResultT, RequestT]):
                         content = create_content(data, None)
                     case Iterable() as contents:
                         contents_list = [
-                            create_content(content_item.content, content_item.mime_type)
-                            for content_item in contents
+                            create_content(content_item.content, content_item.mime_type) for content_item in contents
                         ]
                         return types.ServerResult(
                             types.ReadResourceResult(
@@ -332,9 +315,7 @@ class Server(Generic[LifespanResultT, RequestT]):
                             )
                         )
                     case _:
-                        raise ValueError(
-                            f"Unexpected return type from read_resource: {type(result)}"
-                        )
+                        raise ValueError(f"Unexpected return type from read_resource: {type(result)}")

                 return types.ServerResult(
                     types.ReadResourceResult(
@@ -404,12 +385,7 @@ class Server(Generic[LifespanResultT, RequestT]):
             func: Callable[
                 ...,
                 Awaitable[
-                    Iterable[
-                        types.TextContent
-                        | types.ImageContent
-                        | types.AudioContent
-                        | types.EmbeddedResource
-                    ]
+                    Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
                 ],
             ],
         ):
@@ -418,9 +394,7 @@ class Server(Generic[LifespanResultT, RequestT]):
             async def handler(req: types.CallToolRequest):
                 try:
                     results = await func(req.params.name, (req.params.arguments or {}))
-                    return types.ServerResult(
-                        types.CallToolResult(content=list(results), isError=False)
-                    )
+                    return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
                 except Exception as e:
                     return types.ServerResult(
                         types.CallToolResult(
@@ -436,9 +410,7 @@ class Server(Generic[LifespanResultT, RequestT]):

     def progress_notification(self):
         def decorator(
-            func: Callable[
-                [str | int, float, float | None, str | None], Awaitable[None]
-            ],
+            func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
         ):
             logger.debug("Registering handler for ProgressNotification")

@@ -525,9 +497,7 @@ class Server(Generic[LifespanResultT, RequestT]):

     async def _handle_message(
         self,
-        message: RequestResponder[types.ClientRequest, types.ServerResult]
-        | types.ClientNotification
-        | Exception,
+        message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
         session: ServerSession,
         lifespan_context: LifespanResultT,
         raise_exceptions: bool = False,
@@ -535,20 +505,14 @@ class Server(Generic[LifespanResultT, RequestT]):
         with warnings.catch_warnings(record=True) as w:
             # TODO(Marcelo): We should be checking if message is Exception here.
             match message:  # type: ignore[reportMatchNotExhaustive]
-                case (
-                    RequestResponder(request=types.ClientRequest(root=req)) as responder
-                ):
+                case RequestResponder(request=types.ClientRequest(root=req)) as responder:
                     with responder:
-                        await self._handle_request(
-                            message, req, session, lifespan_context, raise_exceptions
-                        )
+                        await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
                 case types.ClientNotification(root=notify):
                     await self._handle_notification(notify)

             for warning in w:
-                logger.info(
-                    "Warning: %s: %s", warning.category.__name__, warning.message
-                )
+                logger.info("Warning: %s: %s", warning.category.__name__, warning.message)

     async def _handle_request(
         self,
@@ -566,9 +530,7 @@ class Server(Generic[LifespanResultT, RequestT]):
         try:
             # Extract request context from message metadata
             request_data = None
-            if message.message_metadata is not None and isinstance(
-                message.message_metadata, ServerMessageMetadata
-            ):
+            if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
                 request_data = message.message_metadata.request_context

             # Set our global state that can be retrieved via
@@ -64,9 +64,7 @@ class InitializationState(Enum):
 ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

 ServerRequestResponder = (
-    RequestResponder[types.ClientRequest, types.ServerResult]
-    | types.ClientNotification
-    | Exception
+    RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
 )


@@ -89,22 +87,16 @@ class ServerSession(
         init_options: InitializationOptions,
         stateless: bool = False,
     ) -> None:
-        super().__init__(
-            read_stream, write_stream, types.ClientRequest, types.ClientNotification
-        )
+        super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
         self._initialization_state = (
-            InitializationState.Initialized
-            if stateless
-            else InitializationState.NotInitialized
+            InitializationState.Initialized if stateless else InitializationState.NotInitialized
         )

         self._init_options = init_options
-        self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
-            anyio.create_memory_object_stream[ServerRequestResponder](0)
-        )
-        self._exit_stack.push_async_callback(
-            lambda: self._incoming_message_stream_reader.aclose()
-        )
+        self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
+            ServerRequestResponder
+        ](0)
+        self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())

     @property
     def client_params(self) -> types.InitializeRequestParams | None:
@@ -134,10 +126,7 @@ class ServerSession(
                 return False
             # Check each experimental capability
             for exp_key, exp_value in capability.experimental.items():
-                if (
-                    exp_key not in client_caps.experimental
-                    or client_caps.experimental[exp_key] != exp_value
-                ):
+                if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
                     return False

         return True
@@ -146,9 +135,7 @@ class ServerSession(
         async with self._incoming_message_stream_writer:
             await super()._receive_loop()

-    async def _received_request(
-        self, responder: RequestResponder[types.ClientRequest, types.ServerResult]
-    ):
+    async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
         match responder.request.root:
             case types.InitializeRequest(params=params):
                 requested_version = params.protocolVersion
@@ -172,13 +159,9 @@ class ServerSession(
                 )
             case _:
                 if self._initialization_state != InitializationState.Initialized:
-                    raise RuntimeError(
-                        "Received request before initialization was complete"
-                    )
+                    raise RuntimeError("Received request before initialization was complete")

-    async def _received_notification(
-        self, notification: types.ClientNotification
-    ) -> None:
+    async def _received_notification(self, notification: types.ClientNotification) -> None:
         # Need this to avoid ASYNC910
         await anyio.lowlevel.checkpoint()
         match notification.root:
@@ -186,9 +169,7 @@ class ServerSession(
                 self._initialization_state = InitializationState.Initialized
             case _:
                 if self._initialization_state != InitializationState.Initialized:
-                    raise RuntimeError(
-                        "Received notification before initialization was complete"
-                    )
+                    raise RuntimeError("Received notification before initialization was complete")

     async def send_log_message(
         self,
@@ -116,20 +116,14 @@ class SseServerTransport:
         full_message_path_for_client = root_path.rstrip("/") + self._endpoint

         # This is the URI (path + query) the client will use to POST messages.
-        client_post_uri_data = (
-            f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
-        )
+        client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"

-        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
-            dict[str, Any]
-        ](0)
+        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0)

         async def sse_writer():
             logger.debug("Starting SSE writer")
             async with sse_stream_writer, write_stream_reader:
-                await sse_stream_writer.send(
-                    {"event": "endpoint", "data": client_post_uri_data}
-                )
+                await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data})
                 logger.debug(f"Sent endpoint event: {client_post_uri_data}")

                 async for session_message in write_stream_reader:
@@ -137,9 +131,7 @@ class SseServerTransport:
                     await sse_stream_writer.send(
                         {
                             "event": "message",
-                            "data": session_message.message.model_dump_json(
-                                by_alias=True, exclude_none=True
-                            ),
+                            "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True),
                         }
                     )

@@ -151,9 +143,9 @@ class SseServerTransport:
             In this case we close our side of the streams to signal the client that
             the connection has been closed.
             """
-            await EventSourceResponse(
-                content=sse_stream_reader, data_sender_callable=sse_writer
-            )(scope, receive, send)
+            await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
+                scope, receive, send
+            )
             await read_stream_writer.aclose()
             await write_stream_reader.aclose()
             logging.debug(f"Client session disconnected {session_id}")
@@ -164,9 +156,7 @@ class SseServerTransport:
             logger.debug("Yielding read and write streams")
             yield (read_stream, write_stream)

-    async def handle_post_message(
-        self, scope: Scope, receive: Receive, send: Send
-    ) -> None:
+    async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
         logger.debug("Handling POST message")
         request = Request(scope, receive)

|||||||
@@ -76,9 +76,7 @@ async def stdio_server(
|
|||||||
try:
|
try:
|
||||||
async with write_stream_reader:
|
async with write_stream_reader:
|
||||||
async for session_message in write_stream_reader:
|
async for session_message in write_stream_reader:
|
||||||
json = session_message.message.model_dump_json(
|
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
|
||||||
by_alias=True, exclude_none=True
|
|
||||||
)
|
|
||||||
await stdout.write(json + "\n")
|
await stdout.write(json + "\n")
|
||||||
await stdout.flush()
|
await stdout.flush()
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
|
|||||||
@@ -82,9 +82,7 @@ class EventStore(ABC):
     """

     @abstractmethod
-    async def store_event(
-        self, stream_id: StreamId, message: JSONRPCMessage
-    ) -> EventId:
+    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
         """
         Stores an event for later retrieval.

@@ -125,9 +123,7 @@ class StreamableHTTPServerTransport:
     """

    # Server notification streams for POST requests as well as standalone SSE stream
-    _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
-        None
-    )
+    _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None
    _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
    _write_stream: MemoryObjectSendStream[SessionMessage] | None = None
    _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
@@ -153,12 +149,8 @@ class StreamableHTTPServerTransport:
         Raises:
             ValueError: If the session ID contains invalid characters.
         """
-        if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(
-            mcp_session_id
-        ):
-            raise ValueError(
-                "Session ID must only contain visible ASCII characters (0x21-0x7E)"
-            )
+        if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id):
+            raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)")

         self.mcp_session_id = mcp_session_id
         self.is_json_response_enabled = is_json_response_enabled
@@ -218,9 +210,7 @@ class StreamableHTTPServerTransport:
             response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

         return Response(
-            response_message.model_dump_json(by_alias=True, exclude_none=True)
-            if response_message
-            else None,
+            response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None,
             status_code=status_code,
             headers=response_headers,
         )
@@ -233,9 +223,7 @@ class StreamableHTTPServerTransport:
         """Create event data dictionary from an EventMessage."""
         event_data = {
             "event": "message",
-            "data": event_message.message.model_dump_json(
-                by_alias=True, exclude_none=True
-            ),
+            "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
         }

         # If an event ID was provided, include it
@@ -283,42 +271,29 @@ class StreamableHTTPServerTransport:
         accept_header = request.headers.get("accept", "")
         accept_types = [media_type.strip() for media_type in accept_header.split(",")]

-        has_json = any(
-            media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types
-        )
-        has_sse = any(
-            media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types
-        )
+        has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types)
+        has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types)

         return has_json, has_sse

     def _check_content_type(self, request: Request) -> bool:
         """Check if the request has the correct Content-Type."""
         content_type = request.headers.get("content-type", "")
-        content_type_parts = [
-            part.strip() for part in content_type.split(";")[0].split(",")
-        ]
+        content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")]

         return any(part == CONTENT_TYPE_JSON for part in content_type_parts)

-    async def _handle_post_request(
-        self, scope: Scope, request: Request, receive: Receive, send: Send
-    ) -> None:
+    async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
         """Handle POST requests containing JSON-RPC messages."""
         writer = self._read_stream_writer
         if writer is None:
-            raise ValueError(
-                "No read stream writer available. Ensure connect() is called first."
-            )
+            raise ValueError("No read stream writer available. Ensure connect() is called first.")
         try:
             # Check Accept headers
             has_json, has_sse = self._check_accept_headers(request)
             if not (has_json and has_sse):
                 response = self._create_error_response(
-                    (
-                        "Not Acceptable: Client must accept both application/json and "
-                        "text/event-stream"
-                    ),
+                    ("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
                     HTTPStatus.NOT_ACCEPTABLE,
                 )
                 await response(scope, receive, send)
@@ -346,9 +321,7 @@ class StreamableHTTPServerTransport:
             try:
                 raw_message = json.loads(body)
             except json.JSONDecodeError as e:
-                response = self._create_error_response(
-                    f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR
-                )
+                response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
                 await response(scope, receive, send)
                 return

@@ -364,10 +337,7 @@ class StreamableHTTPServerTransport:
                 return

             # Check if this is an initialization request
-            is_initialization_request = (
-                isinstance(message.root, JSONRPCRequest)
-                and message.root.method == "initialize"
-            )
+            is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"

             if is_initialization_request:
                 # Check if the server already has an established session
@@ -406,9 +376,7 @@ class StreamableHTTPServerTransport:
             # Extract the request ID outside the try block for proper scope
             request_id = str(message.root.id)
             # Register this stream for the request ID
-            self._request_streams[request_id] = anyio.create_memory_object_stream[
-                EventMessage
-            ](0)
+            self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
             request_stream_reader = self._request_streams[request_id][1]

             if self.is_json_response_enabled:
@@ -424,16 +392,12 @@ class StreamableHTTPServerTransport:
                     # Use similar approach to SSE writer for consistency
                     async for event_message in request_stream_reader:
                         # If it's a response, this is what we're waiting for
-                        if isinstance(
-                            event_message.message.root, JSONRPCResponse | JSONRPCError
-                        ):
+                        if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
                             response_message = event_message.message
                             break
                         # For notifications and request, keep waiting
                         else:
-                            logger.debug(
-                                f"received: {event_message.message.root.method}"
-                            )
+                            logger.debug(f"received: {event_message.message.root.method}")

                     # At this point we should have a response
                     if response_message:
@@ -442,9 +406,7 @@ class StreamableHTTPServerTransport:
                         await response(scope, receive, send)
                     else:
                         # This shouldn't happen in normal operation
-                        logger.error(
-                            "No response message received before stream closed"
-                        )
+                        logger.error("No response message received before stream closed")
                         response = self._create_error_response(
                             "Error processing request: No response received",
                             HTTPStatus.INTERNAL_SERVER_ERROR,
@@ -462,9 +424,7 @@ class StreamableHTTPServerTransport:
                     await self._clean_up_memory_streams(request_id)
             else:
                 # Create SSE stream
-                sse_stream_writer, sse_stream_reader = (
-                    anyio.create_memory_object_stream[dict[str, str]](0)
-                )
+                sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

                 async def sse_writer():
                     # Get the request ID from the incoming request message
@@ -495,11 +455,7 @@ class StreamableHTTPServerTransport:
                     "Cache-Control": "no-cache, no-transform",
                     "Connection": "keep-alive",
                     "Content-Type": CONTENT_TYPE_SSE,
-                    **(
-                        {MCP_SESSION_ID_HEADER: self.mcp_session_id}
-                        if self.mcp_session_id
-                        else {}
-                    ),
+                    **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
                 }
                 response = EventSourceResponse(
                     content=sse_stream_reader,
@@ -544,9 +500,7 @@ class StreamableHTTPServerTransport:
         """
         writer = self._read_stream_writer
         if writer is None:
-            raise ValueError(
-                "No read stream writer available. Ensure connect() is called first."
-            )
+            raise ValueError("No read stream writer available. Ensure connect() is called first.")

         # Validate Accept header - must include text/event-stream
         _, has_sse = self._check_accept_headers(request)
@@ -585,17 +539,13 @@ class StreamableHTTPServerTransport:
             return

         # Create SSE stream
-        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
-            dict[str, str]
-        ](0)
+        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

         async def standalone_sse_writer():
             try:
                 # Create a standalone message stream for server-initiated messages

-                self._request_streams[GET_STREAM_KEY] = (
-                    anyio.create_memory_object_stream[EventMessage](0)
-                )
+                self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
                 standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

                 async with sse_stream_writer, standalone_stream_reader:
@@ -732,9 +682,7 @@ class StreamableHTTPServerTransport:

         return True

-    async def _replay_events(
-        self, last_event_id: str, request: Request, send: Send
-    ) -> None:
+    async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
         """
         Replays events that would have been sent after the specified event ID.
         Only used when resumability is enabled.
@@ -754,9 +702,7 @@ class StreamableHTTPServerTransport:
             headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

         # Create SSE stream for replay
-        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
-            dict[str, str]
-        ](0)
+        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

         async def replay_sender():
             try:
@@ -767,15 +713,11 @@ class StreamableHTTPServerTransport:
                         await sse_stream_writer.send(event_data)

                     # Replay past events and get the stream ID
-                    stream_id = await event_store.replay_events_after(
-                        last_event_id, send_event
-                    )
+                    stream_id = await event_store.replay_events_after(last_event_id, send_event)

                     # If stream ID not in mapping, create it
                     if stream_id and stream_id not in self._request_streams:
-                        self._request_streams[stream_id] = (
-                            anyio.create_memory_object_stream[EventMessage](0)
-                        )
+                        self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
                         msg_reader = self._request_streams[stream_id][1]

                         # Forward messages to SSE
@@ -829,12 +771,8 @@ class StreamableHTTPServerTransport:

         # Create the memory streams for this connection

-        read_stream_writer, read_stream = anyio.create_memory_object_stream[
-            SessionMessage | Exception
-        ](0)
-        write_stream, write_stream_reader = anyio.create_memory_object_stream[
-            SessionMessage
-        ](0)
+        read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
+        write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

         # Store the streams
         self._read_stream_writer = read_stream_writer
@@ -867,35 +805,24 @@ class StreamableHTTPServerTransport:
                         session_message.metadata,
                         ServerMessageMetadata,
                     )
-                    and session_message.metadata.related_request_id
-                    is not None
+                    and session_message.metadata.related_request_id is not None
                 ):
-                    target_request_id = str(
-                        session_message.metadata.related_request_id
-                    )
+                    target_request_id = str(session_message.metadata.related_request_id)

-                request_stream_id = (
-                    target_request_id
-                    if target_request_id is not None
-                    else GET_STREAM_KEY
-                )
+                request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY

                 # Store the event if we have an event store,
                 # regardless of whether a client is connected
                 # messages will be replayed on the re-connect
                 event_id = None
                 if self._event_store:
-                    event_id = await self._event_store.store_event(
-                        request_stream_id, message
-                    )
+                    event_id = await self._event_store.store_event(request_stream_id, message)
                     logger.debug(f"Stored {event_id} from {request_stream_id}")

                 if request_stream_id in self._request_streams:
                     try:
                         # Send both the message and the event ID
-                        await self._request_streams[request_stream_id][0].send(
-                            EventMessage(message, event_id)
-                        )
+                        await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
                     except (
                         anyio.BrokenResourceError,
                         anyio.ClosedResourceError,
@@ -165,9 +165,7 @@ class StreamableHTTPSessionManager:
         )

         # Start server in a new task
-        async def run_stateless_server(
-            *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
-        ):
+        async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
             async with http_transport.connect() as streams:
                 read_stream, write_stream = streams
                 task_status.started()
@@ -204,10 +202,7 @@ class StreamableHTTPSessionManager:
         request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

         # Existing session case
-        if (
-            request_mcp_session_id is not None
-            and request_mcp_session_id in self._server_instances
-        ):
+        if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
             transport = self._server_instances[request_mcp_session_id]
             logger.debug("Session already exists, handling request directly")
             await transport.handle_request(scope, receive, send)
@@ -229,9 +224,7 @@ class StreamableHTTPSessionManager:
             logger.info(f"Created new transport with session ID: {new_session_id}")

             # Define the server runner
-            async def run_server(
-                *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
-            ) -> None:
+            async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
                 async with http_transport.connect() as streams:
                     read_stream, write_stream = streams
                     task_status.started()
@@ -93,12 +93,8 @@ class StreamingASGITransport(AsyncBaseTransport):
         initial_response_ready = anyio.Event()

         # Synchronization for streaming response
-        asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[
-            dict[str, Any]
-        ](100)
-        content_send_channel, content_receive_channel = (
-            anyio.create_memory_object_stream[bytes](100)
-        )
+        asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100)
+        content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100)

         # ASGI callables.
         async def receive() -> dict[str, Any]:
@@ -124,21 +120,15 @@ class StreamingASGITransport(AsyncBaseTransport):
         async def run_app() -> None:
             try:
                 # Cast the receive and send functions to the ASGI types
-                await self.app(
-                    cast(Scope, scope), cast(Receive, receive), cast(Send, send)
-                )
+                await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send))
             except Exception:
                 if self.raise_app_exceptions:
                     raise

                 if not response_started:
-                    await asgi_send_channel.send(
-                        {"type": "http.response.start", "status": 500, "headers": []}
-                    )
+                    await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []})

-                await asgi_send_channel.send(
-                    {"type": "http.response.body", "body": b"", "more_body": False}
-                )
+                await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False})
             finally:
                 await asgi_send_channel.aclose()

@@ -51,9 +51,7 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
         try:
             async with write_stream_reader:
                 async for session_message in write_stream_reader:
-                    obj = session_message.message.model_dump_json(
-                        by_alias=True, exclude_none=True
-                    )
+                    obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
                     await websocket.send_text(obj)
         except anyio.ClosedResourceError:
             await websocket.close()
@@ -45,9 +45,7 @@ class OAuthClientMetadata(BaseModel):
     # token_endpoint_auth_method: this implementation only supports none &
     # client_secret_post;
     # ie: we do not support client_secret_basic
-    token_endpoint_auth_method: Literal["none", "client_secret_post"] = (
-        "client_secret_post"
-    )
+    token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
     # grant_types: this implementation only supports authorization_code & refresh_token
     grant_types: list[Literal["authorization_code", "refresh_token"]] = [
         "authorization_code",
@@ -84,17 +82,12 @@ class OAuthClientMetadata(BaseModel):
         if redirect_uri is not None:
             # Validate redirect_uri against client's registered redirect URIs
             if redirect_uri not in self.redirect_uris:
-                raise InvalidRedirectUriError(
-                    f"Redirect URI '{redirect_uri}' not registered for client"
-                )
+                raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
             return redirect_uri
         elif len(self.redirect_uris) == 1:
             return self.redirect_uris[0]
         else:
-            raise InvalidRedirectUriError(
-                "redirect_uri must be specified when client "
-                "has multiple registered URIs"
-            )
+            raise InvalidRedirectUriError("redirect_uri must be specified when client " "has multiple registered URIs")


 class OAuthClientInformationFull(OAuthClientMetadata):
@@ -11,26 +11,15 @@ import anyio
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

 import mcp.types as types
-from mcp.client.session import (
-    ClientSession,
-    ListRootsFnT,
-    LoggingFnT,
-    MessageHandlerFnT,
-    SamplingFnT,
-)
+from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
 from mcp.server import Server
 from mcp.shared.message import SessionMessage

-MessageStream = tuple[
-    MemoryObjectReceiveStream[SessionMessage | Exception],
-    MemoryObjectSendStream[SessionMessage],
-]
+MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]


 @asynccontextmanager
-async def create_client_server_memory_streams() -> (
-    AsyncGenerator[tuple[MessageStream, MessageStream], None]
-):
+async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]:
     """
     Creates a pair of bidirectional memory streams for client-server communication.

@@ -39,12 +28,8 @@ async def create_client_server_memory_streams() -> (
         (read_stream, write_stream)
     """
     # Create streams for both directions
-    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-        SessionMessage | Exception
-    ](1)
-    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-        SessionMessage | Exception
-    ](1)
+    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
+    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

     client_streams = (server_to_client_receive, client_to_server_send)
     server_streams = (client_to_server_receive, server_to_client_send)
@@ -20,9 +20,7 @@ class ClientMessageMetadata:
     """Metadata specific to client messages."""

     resumption_token: ResumptionToken | None = None
-    on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = (
-        None
-    )
+    on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None


 @dataclass
@@ -23,22 +23,8 @@ class Progress(BaseModel):


 @dataclass
-class ProgressContext(
-    Generic[
-        SendRequestT,
-        SendNotificationT,
-        SendResultT,
-        ReceiveRequestT,
-        ReceiveNotificationT,
-    ]
-):
-    session: BaseSession[
-        SendRequestT,
-        SendNotificationT,
-        SendResultT,
-        ReceiveRequestT,
-        ReceiveNotificationT,
-    ]
+class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
+    session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
     progress_token: ProgressToken
     total: float | None
     current: float = field(default=0.0, init=False)
@@ -54,24 +40,12 @@ class ProgressContext(
 @contextmanager
 def progress(
     ctx: RequestContext[
-        BaseSession[
-            SendRequestT,
-            SendNotificationT,
-            SendResultT,
-            ReceiveRequestT,
-            ReceiveNotificationT,
-        ],
+        BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
         LifespanContextT,
     ],
     total: float | None = None,
 ) -> Generator[
-    ProgressContext[
-        SendRequestT,
-        SendNotificationT,
-        SendResultT,
-        ReceiveRequestT,
-        ReceiveNotificationT,
-    ],
+    ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
     None,
 ]:
     if ctx.meta is None or ctx.meta.progressToken is None:
@@ -38,9 +38,7 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
 SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
 ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
 ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
-ReceiveNotificationT = TypeVar(
-    "ReceiveNotificationT", ClientNotification, ServerNotification
-)
+ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)

 RequestId = str | int

@@ -48,9 +46,7 @@ RequestId = str | int
 class ProgressFnT(Protocol):
     """Protocol for progress notification callbacks."""

-    async def __call__(
-        self, progress: float, total: float | None, message: str | None
-    ) -> None: ...
+    async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...


 class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@@ -177,9 +173,7 @@ class BaseSession(
     messages when entered.
     """

-    _response_streams: dict[
-        RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
-    ]
+    _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
     _request_id: int
     _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
     _progress_callbacks: dict[RequestId, ProgressFnT]
@@ -242,9 +236,7 @@ class BaseSession(
         request_id = self._request_id
         self._request_id = request_id + 1

-        response_stream, response_stream_reader = anyio.create_memory_object_stream[
-            JSONRPCResponse | JSONRPCError
-        ](1)
+        response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
         self._response_streams[request_id] = response_stream

         # Set up progress token if progress callback is provided
@@ -266,11 +258,7 @@ class BaseSession(
             **request_data,
         )

-        await self._write_stream.send(
-            SessionMessage(
-                message=JSONRPCMessage(jsonrpc_request), metadata=metadata
-            )
-        )
+        await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

         # request read timeout takes precedence over session read timeout
         timeout = None
@@ -322,15 +310,11 @@ class BaseSession(
         )
         session_message = SessionMessage(
             message=JSONRPCMessage(jsonrpc_notification),
-            metadata=ServerMessageMetadata(related_request_id=related_request_id)
-            if related_request_id
-            else None,
+            metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
         )
         await self._write_stream.send(session_message)

-    async def _send_response(
-        self, request_id: RequestId, response: SendResultT | ErrorData
-    ) -> None:
+    async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
         if isinstance(response, ErrorData):
             jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
             session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -339,9 +323,7 @@ class BaseSession(
             jsonrpc_response = JSONRPCResponse(
                 jsonrpc="2.0",
                 id=request_id,
-                result=response.model_dump(
-                    by_alias=True, mode="json", exclude_none=True
-                ),
+                result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
             )
             session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
             await self._write_stream.send(session_message)
@@ -357,19 +339,14 @@ class BaseSession(
                 elif isinstance(message.message.root, JSONRPCRequest):
                     try:
                         validated_request = self._receive_request_type.model_validate(
-                            message.message.root.model_dump(
-                                by_alias=True, mode="json", exclude_none=True
-                            )
+                            message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                         )
                         responder = RequestResponder(
                             request_id=message.message.root.id,
-                            request_meta=validated_request.root.params.meta
-                            if validated_request.root.params
-                            else None,
+                            request_meta=validated_request.root.params.meta if validated_request.root.params else None,
                             request=validated_request,
                             session=self,
-                            on_complete=lambda r: self._in_flight.pop(
-                                r.request_id, None),
+                            on_complete=lambda r: self._in_flight.pop(r.request_id, None),
                             message_metadata=message.metadata,
                         )
                         self._in_flight[responder.request_id] = responder
@@ -381,9 +358,7 @@ class BaseSession(
                         # For request validation errors, send a proper JSON-RPC error
                         # response instead of crashing the server
                         logging.warning(f"Failed to validate request: {e}")
-                        logging.debug(
-                            f"Message that failed validation: {message.message.root}"
-                        )
+                        logging.debug(f"Message that failed validation: {message.message.root}")
                         error_response = JSONRPCError(
                             jsonrpc="2.0",
                             id=message.message.root.id,
@@ -393,16 +368,13 @@ class BaseSession(
                                 data="",
                             ),
                         )
-                        session_message = SessionMessage(
-                            message=JSONRPCMessage(error_response))
+                        session_message = SessionMessage(message=JSONRPCMessage(error_response))
                         await self._write_stream.send(session_message)

                 elif isinstance(message.message.root, JSONRPCNotification):
                     try:
                         notification = self._receive_notification_type.model_validate(
-                            message.message.root.model_dump(
-                                by_alias=True, mode="json", exclude_none=True
-                            )
+                            message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                         )
                         # Handle cancellation notifications
                         if isinstance(notification.root, CancelledNotification):
@@ -427,8 +399,7 @@ class BaseSession(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# For other validation errors, log and continue
|
# For other validation errors, log and continue
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Failed to validate notification: {e}. "
|
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
|
||||||
f"Message was: {message.message.root}"
|
|
||||||
)
|
)
|
||||||
else: # Response or error
|
else: # Response or error
|
||||||
stream = self._response_streams.pop(message.message.root.id, None)
|
stream = self._response_streams.pop(message.message.root.id, None)
|
||||||
@@ -436,10 +407,7 @@ class BaseSession(
|
|||||||
await stream.send(message.message.root)
|
await stream.send(message.message.root)
|
||||||
else:
|
else:
|
||||||
await self._handle_incoming(
|
await self._handle_incoming(
|
||||||
RuntimeError(
|
RuntimeError("Received response with an unknown " f"request ID: {message}")
|
||||||
"Received response with an unknown "
|
|
||||||
f"request ID: {message}"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# after the read stream is closed, we need to send errors
|
# after the read stream is closed, we need to send errors
|
||||||
@@ -450,9 +418,7 @@ class BaseSession(
|
|||||||
await stream.aclose()
|
await stream.aclose()
|
||||||
self._response_streams.clear()
|
self._response_streams.clear()
|
||||||
|
|
||||||
async def _received_request(
|
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||||
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Can be overridden by subclasses to handle a request without needing to
|
Can be overridden by subclasses to handle a request without needing to
|
||||||
listen on the message stream.
|
listen on the message stream.
|
||||||
@@ -481,9 +447,7 @@ class BaseSession(
|
|||||||
|
|
||||||
async def _handle_incoming(
|
async def _handle_incoming(
|
||||||
self,
|
self,
|
||||||
req: RequestResponder[ReceiveRequestT, SendResultT]
|
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||||
| ReceiveNotificationT
|
|
||||||
| Exception,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
from collections.abc import Callable
-from typing import (
-Annotated,
-Any,
-Generic,
-Literal,
-TypeAlias,
-TypeVar,
-)
+from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar

from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@@ -73,9 +66,7 @@ class NotificationParams(BaseModel):


RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
-NotificationParamsT = TypeVar(
-"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
-)
+NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)


@@ -87,9 +78,7 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
model_config = ConfigDict(extra="allow")


-class PaginatedRequest(
-Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
-):
+class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""

@@ -191,9 +180,7 @@ class JSONRPCError(BaseModel):
model_config = ConfigDict(extra="allow")


-class JSONRPCMessage(
-RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]
-):
+class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
pass


@@ -314,9 +301,7 @@ class InitializeResult(Result):
"""Instructions describing how to use the server and its features."""


-class InitializedNotification(
-Notification[NotificationParams | None, Literal["notifications/initialized"]]
-):
+class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]):
"""
This notification is sent from the client to the server after initialization has
finished.
@@ -359,9 +344,7 @@ class ProgressNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


-class ProgressNotification(
-Notification[ProgressNotificationParams, Literal["notifications/progress"]]
-):
+class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]):
"""
An out-of-band notification used to inform the receiver of a progress update for a
long-running request.
@@ -432,9 +415,7 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]


-class ListResourceTemplatesRequest(
-PaginatedRequest[Literal["resources/templates/list"]]
-):
+class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
"""Sent from the client to request a list of resource templates the server has."""

method: Literal["resources/templates/list"]
@@ -457,9 +438,7 @@ class ReadResourceRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


-class ReadResourceRequest(
-Request[ReadResourceRequestParams, Literal["resources/read"]]
-):
+class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
"""Sent from the client to the server, to read a specific resource URI."""

method: Literal["resources/read"]
@@ -500,9 +479,7 @@ class ReadResourceResult(Result):


class ResourceListChangedNotification(
-Notification[
-NotificationParams | None, Literal["notifications/resources/list_changed"]
-]
+Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]]
):
"""
An optional notification from the server to the client, informing it that the list
@@ -542,9 +519,7 @@ class UnsubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


-class UnsubscribeRequest(
-Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
-):
+class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]):
"""
Sent from the client to request cancellation of resources/updated notifications from
the server.
@@ -566,9 +541,7 @@ class ResourceUpdatedNotificationParams(NotificationParams):


class ResourceUpdatedNotification(
-Notification[
-ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
-]
+Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]]
):
"""
A notification from the server to the client, informing it that a resource has
@@ -711,9 +684,7 @@ class GetPromptResult(Result):


class PromptListChangedNotification(
-Notification[
-NotificationParams | None, Literal["notifications/prompts/list_changed"]
-]
+Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]]
):
"""
An optional notification from the server to the client, informing it that the list
@@ -820,9 +791,7 @@ class CallToolResult(Result):
isError: bool = False


-class ToolListChangedNotification(
-Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
-):
+class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]):
"""
An optional notification from the server to the client, informing it that the list
of tools it offers has changed.
@@ -832,9 +801,7 @@ class ToolListChangedNotification(
params: NotificationParams | None = None


-LoggingLevel = Literal[
-"debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"
-]
+LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"]


class SetLevelRequestParams(RequestParams):
@@ -867,9 +834,7 @@ class LoggingMessageNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


-class LoggingMessageNotification(
-Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
-):
+class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
"""Notification of a log message passed from server to client."""

method: Literal["notifications/message"]
@@ -964,9 +929,7 @@ class CreateMessageRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


-class CreateMessageRequest(
-Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
-):
+class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
"""A request from the server to sample an LLM via the client."""

method: Literal["sampling/createMessage"]
@@ -1123,9 +1086,7 @@ class CancelledNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


-class CancelledNotification(
-Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]
-):
+class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]):
"""
This notification can be sent by either side to indicate that it is canceling a
previously-issued request.
@@ -1156,12 +1117,7 @@ class ClientRequest(


class ClientNotification(
-RootModel[
-CancelledNotification
-| ProgressNotification
-| InitializedNotification
-| RootsListChangedNotification
-]
+RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification]
):
pass

@@ -49,8 +49,7 @@ class StreamSpyCollection:
return [
req.message.root
for req in self.client.sent_messages
-if isinstance(req.message.root, JSONRPCRequest)
-and (method is None or req.message.root.method == method)
+if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
]

def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
@@ -58,13 +57,10 @@ class StreamSpyCollection:
return [
req.message.root
for req in self.server.sent_messages
-if isinstance(req.message.root, JSONRPCRequest)
-and (method is None or req.message.root.method == method)
+if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method)
]

-def get_client_notifications(
-self, method: str | None = None
-) -> list[JSONRPCNotification]:
+def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
"""Get client-sent notifications, optionally filtered by method."""
return [
notif.message.root
@@ -73,9 +69,7 @@ class StreamSpyCollection:
and (method is None or notif.message.root.method == method)
]

-def get_server_notifications(
-self, method: str | None = None
-) -> list[JSONRPCNotification]:
+def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]:
"""Get server-sent notifications, optionally filtered by method."""
return [
notif.message.root
@@ -133,9 +127,7 @@ def stream_spy():
yield (client_read, spy_client_write), (server_read, spy_server_write)

# Apply the patch for the duration of the test
-with patch(
-"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
-):
+with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams):
# Return a collection with helper methods
def get_spy_collection() -> StreamSpyCollection:
assert client_spy is not None, "client_spy was not initialized"
@@ -134,9 +134,7 @@ class TestOAuthClientProvider:
assert len(verifier) == 128

# Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~")
-allowed_chars = set(
-"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
-)
+allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")
assert set(verifier) <= allowed_chars

# Check uniqueness (generate multiple and ensure they're different)
@@ -151,9 +149,7 @@ class TestOAuthClientProvider:

# Manually calculate expected challenge
expected_digest = hashlib.sha256(verifier.encode()).digest()
-expected_challenge = (
-base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")
-)
+expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=")

assert challenge == expected_challenge

@@ -166,29 +162,19 @@ class TestOAuthClientProvider:
async def test_get_authorization_base_url(self, oauth_provider):
"""Test authorization base URL extraction."""
# Test with path
-assert (
-oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
-== "https://api.example.com"
-)
+assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"

# Test with no path
-assert (
-oauth_provider._get_authorization_base_url("https://api.example.com")
-== "https://api.example.com"
-)
+assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com"

# Test with port
assert (
-oauth_provider._get_authorization_base_url(
-"https://api.example.com:8080/path/to/mcp"
-)
+oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
== "https://api.example.com:8080"
)

@pytest.mark.anyio
-async def test_discover_oauth_metadata_success(
-self, oauth_provider, oauth_metadata
-):
+async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
"""Test successful OAuth metadata discovery."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")

@@ -201,23 +187,16 @@ class TestOAuthClientProvider:
mock_response.json.return_value = metadata_response
mock_client.get.return_value = mock_response

-result = await oauth_provider._discover_oauth_metadata(
-"https://api.example.com/v1/mcp"
-)
+result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")

assert result is not None
-assert (
-result.authorization_endpoint == oauth_metadata.authorization_endpoint
-)
+assert result.authorization_endpoint == oauth_metadata.authorization_endpoint
assert result.token_endpoint == oauth_metadata.token_endpoint

# Verify correct URL was called
mock_client.get.assert_called_once()
call_args = mock_client.get.call_args[0]
-assert (
-call_args[0]
-== "https://api.example.com/.well-known/oauth-authorization-server"
-)
+assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server"

@pytest.mark.anyio
async def test_discover_oauth_metadata_not_found(self, oauth_provider):
@@ -230,16 +209,12 @@ class TestOAuthClientProvider:
mock_response.status_code = 404
mock_client.get.return_value = mock_response

-result = await oauth_provider._discover_oauth_metadata(
-"https://api.example.com/v1/mcp"
-)
+result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")

assert result is None

@pytest.mark.anyio
-async def test_discover_oauth_metadata_cors_fallback(
-self, oauth_provider, oauth_metadata
-):
+async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata):
"""Test OAuth metadata discovery with CORS fallback."""
metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")

@@ -257,17 +232,13 @@ class TestOAuthClientProvider:
mock_response_success, # Second call succeeds
]

-result = await oauth_provider._discover_oauth_metadata(
-"https://api.example.com/v1/mcp"
-)
+result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp")

assert result is not None
assert mock_client.get.call_count == 2

@pytest.mark.anyio
-async def test_register_oauth_client_success(
-self, oauth_provider, oauth_metadata, oauth_client_info
-):
+async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth client registration."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")

@@ -295,9 +266,7 @@ class TestOAuthClientProvider:
assert call_args[0][0] == str(oauth_metadata.registration_endpoint)

@pytest.mark.anyio
-async def test_register_oauth_client_fallback_endpoint(
-self, oauth_provider, oauth_client_info
-):
+async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info):
"""Test OAuth client registration with fallback endpoint."""
registration_response = oauth_client_info.model_dump(by_alias=True, mode="json")

@@ -311,9 +280,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response

# Mock metadata discovery to return None (fallback)
-with patch.object(
-oauth_provider, "_discover_oauth_metadata", return_value=None
-):
+with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
result = await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
oauth_provider.client_metadata,
@@ -340,9 +307,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response

# Mock metadata discovery to return None (fallback)
-with patch.object(
-oauth_provider, "_discover_oauth_metadata", return_value=None
-):
+with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None):
with pytest.raises(httpx.HTTPStatusError):
await oauth_provider._register_oauth_client(
"https://api.example.com/v1/mcp",
@@ -406,9 +371,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)

@pytest.mark.anyio
-async def test_validate_token_scopes_unauthorized(
-self, oauth_provider, client_metadata
-):
+async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata):
"""Test scope validation with unauthorized scopes."""
oauth_provider.client_metadata = client_metadata
token = OAuthToken(
@@ -436,9 +399,7 @@ class TestOAuthClientProvider:
await oauth_provider._validate_token_scopes(token)

@pytest.mark.anyio
-async def test_initialize(
-self, oauth_provider, mock_storage, oauth_token, oauth_client_info
-):
+async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info):
"""Test initialization loading from storage."""
mock_storage._tokens = oauth_token
mock_storage._client_info = oauth_client_info
@@ -449,9 +410,7 @@ class TestOAuthClientProvider:
assert oauth_provider._client_info == oauth_client_info

@pytest.mark.anyio
-async def test_get_or_register_client_existing(
-self, oauth_provider, oauth_client_info
-):
+async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info):
"""Test getting existing client info."""
oauth_provider._client_info = oauth_client_info

@@ -460,13 +419,9 @@ class TestOAuthClientProvider:
assert result == oauth_client_info

@pytest.mark.anyio
-async def test_get_or_register_client_register_new(
-self, oauth_provider, oauth_client_info
-):
+async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info):
"""Test registering new client."""
-with patch.object(
-oauth_provider, "_register_oauth_client", return_value=oauth_client_info
-) as mock_register:
+with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register:
result = await oauth_provider._get_or_register_client()

assert result == oauth_client_info
@@ -474,9 +429,7 @@ class TestOAuthClientProvider:
mock_register.assert_called_once()

@pytest.mark.anyio
-async def test_exchange_code_for_token_success(
-self, oauth_provider, oauth_client_info, oauth_token
-):
+async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful code exchange for token."""
oauth_provider._code_verifier = "test_verifier"
token_response = oauth_token.model_dump(by_alias=True, mode="json")
@@ -490,23 +443,14 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response

-with patch.object(
-oauth_provider, "_validate_token_scopes"
-) as mock_validate:
-await oauth_provider._exchange_code_for_token(
-"test_auth_code", oauth_client_info
-)
+with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
+await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info)

-assert (
-oauth_provider._current_tokens.access_token
-== oauth_token.access_token
-)
+assert oauth_provider._current_tokens.access_token == oauth_token.access_token
mock_validate.assert_called_once()

@pytest.mark.anyio
-async def test_exchange_code_for_token_failure(
-self, oauth_provider, oauth_client_info
-):
+async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info):
"""Test failed code exchange for token."""
oauth_provider._code_verifier = "test_verifier"

@@ -520,14 +464,10 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response

with pytest.raises(Exception, match="Token exchange failed"):
-await oauth_provider._exchange_code_for_token(
-"invalid_auth_code", oauth_client_info
-)
+await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)

@pytest.mark.anyio
-async def test_refresh_access_token_success(
-self, oauth_provider, oauth_client_info, oauth_token
-):
+async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token):
"""Test successful token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -550,16 +490,11 @@ class TestOAuthClientProvider:
mock_response.json.return_value = token_response
mock_client.post.return_value = mock_response

-with patch.object(
-oauth_provider, "_validate_token_scopes"
-) as mock_validate:
+with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate:
result = await oauth_provider._refresh_access_token()

assert result is True
-assert (
-oauth_provider._current_tokens.access_token
-== new_token.access_token
-)
+assert oauth_provider._current_tokens.access_token == new_token.access_token
mock_validate.assert_called_once()

@pytest.mark.anyio
@@ -575,9 +510,7 @@ class TestOAuthClientProvider:
assert result is False

@pytest.mark.anyio
-async def test_refresh_access_token_failure(
-self, oauth_provider, oauth_client_info, oauth_token
-):
+async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token):
"""Test failed token refresh."""
oauth_provider._current_tokens = oauth_token
oauth_provider._client_info = oauth_client_info
@@ -594,9 +527,7 @@ class TestOAuthClientProvider:
assert result is False

@pytest.mark.anyio
-async def test_perform_oauth_flow_success(
-self, oauth_provider, oauth_metadata, oauth_client_info
-):
+async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test successful OAuth flow."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -640,9 +571,7 @@ class TestOAuthClientProvider:
mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info)

@pytest.mark.anyio
-async def test_perform_oauth_flow_state_mismatch(
-self, oauth_provider, oauth_metadata, oauth_client_info
-):
+async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test OAuth flow with state parameter mismatch."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -678,9 +607,7 @@ class TestOAuthClientProvider:
oauth_provider._current_tokens = oauth_token
oauth_provider._token_expiry_time = time.time() - 3600 # Expired

-with patch.object(
-oauth_provider, "_refresh_access_token", return_value=True
-) as mock_refresh:
+with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh:
await oauth_provider.ensure_token()
mock_refresh.assert_called_once()

@@ -707,10 +634,7 @@ class TestOAuthClientProvider:
auth_flow = oauth_provider.async_auth_flow(request)
updated_request = await auth_flow.__anext__()

-assert (
-updated_request.headers["Authorization"]
-== f"Bearer {oauth_token.access_token}"
-)
+assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"

# Send mock response
try:
@@ -761,9 +685,7 @@ class TestOAuthClientProvider:
assert "Authorization" not in updated_request.headers

@pytest.mark.anyio
-async def test_scope_priority_client_metadata_first(
-self, oauth_provider, oauth_client_info
-):
+async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):
"""Test that client metadata scope takes priority."""
oauth_provider.client_metadata.scope = "read write"
oauth_provider._client_info = oauth_client_info
@@ -782,18 +704,13 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
-elif (
-hasattr(oauth_provider._client_info, "scope")
-and oauth_provider._client_info.scope
-):
+elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope

assert auth_params["scope"] == "read write"

@pytest.mark.anyio
-async def test_scope_priority_no_client_metadata_scope(
-self, oauth_provider, oauth_client_info
-):
+async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info):
"""Test that no scope parameter is set when client metadata has no scope."""
oauth_provider.client_metadata.scope = None
oauth_provider._client_info = oauth_client_info
@@ -837,10 +754,7 @@ class TestOAuthClientProvider:
# Apply scope logic from _perform_oauth_flow
if oauth_provider.client_metadata.scope:
auth_params["scope"] = oauth_provider.client_metadata.scope
-elif (
-hasattr(oauth_provider._client_info, "scope")
-and oauth_provider._client_info.scope
-):
+elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope:
auth_params["scope"] = oauth_provider._client_info.scope

# No scope should be set
@@ -866,9 +780,7 @@ class TestOAuthClientProvider:
oauth_provider.redirect_handler = mock_redirect_handler

# Patch secrets.compare_digest to verify it's being called
-with patch(
-"mcp.client.auth.secrets.compare_digest", return_value=False
-) as mock_compare:
+with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare:
with pytest.raises(Exception, match="State parameter mismatch"):
await oauth_provider._perform_oauth_flow()

@@ -876,9 +788,7 @@ class TestOAuthClientProvider:
mock_compare.assert_called_once()

@pytest.mark.anyio
-async def test_state_parameter_validation_none_state(
-self, oauth_provider, oauth_metadata, oauth_client_info
-):
+async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info):
"""Test that None state is handled correctly."""
oauth_provider._metadata = oauth_metadata
oauth_provider._client_info = oauth_client_info
@@ -913,9 +823,7 @@ class TestOAuthClientProvider:
mock_client.post.return_value = mock_response

with pytest.raises(Exception, match="Token exchange failed"):
-await oauth_provider._exchange_code_for_token(
-"invalid_auth_code", oauth_client_info
-)
+await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info)


@pytest.mark.parametrize(
@@ -968,9 +876,7 @@ def test_build_metadata(
metadata = build_metadata(
issuer_url=AnyHttpUrl(issuer_url),
service_documentation_url=AnyHttpUrl(service_documentation_url),
-client_registration_options=ClientRegistrationOptions(
-enabled=True, valid_scopes=["read", "write", "admin"]
-),
+client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]),
revocation_options=RevocationOptions(enabled=True),
)

@@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path):

test_args = [command] + args + ["--help"]

-result = subprocess.run(
-test_args, capture_output=True, text=True, timeout=5, check=False
-)
+result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False)

assert result.returncode == 0
assert "usage" in result.stdout.lower()
@@ -182,9 +182,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):

# Test without cursor parameter (omitted)
_ = await client_session.list_resource_templates()
-list_templates_requests = spies.get_client_requests(
-method="resources/templates/list"
-)
+list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None

@@ -192,9 +190,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):

# Test with cursor=None
_ = await client_session.list_resource_templates(cursor=None)
-list_templates_requests = spies.get_client_requests(
-method="resources/templates/list"
-)
+list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is None

@@ -202,9 +198,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):

# Test with cursor as string
_ = await client_session.list_resource_templates(cursor="some_cursor")
-list_templates_requests = spies.get_client_requests(
-method="resources/templates/list"
-)
+list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == "some_cursor"
@@ -213,9 +207,7 @@ async def test_list_resource_templates_cursor_parameter(stream_spy):

# Test with empty string cursor
_ = await client_session.list_resource_templates(cursor="")
-list_templates_requests = spies.get_client_requests(
-method="resources/templates/list"
-)
+list_templates_requests = spies.get_client_requests(method="resources/templates/list")
assert len(list_templates_requests) == 1
assert list_templates_requests[0].params is not None
assert list_templates_requests[0].params["cursor"] == ""
@@ -41,13 +41,9 @@ async def test_list_roots_callback():
return True

# Test with list_roots callback
-async with create_session(
-server._mcp_server, list_roots_callback=list_roots_callback
-) as client_session:
+async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session:
# Make a request to trigger sampling callback
-result = await client_session.call_tool(
-"test_list_roots", {"message": "test message"}
-)
+result = await client_session.call_tool("test_list_roots", {"message": "test message"})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
@@ -55,12 +51,7 @@ async def test_list_roots_callback():
# Test without list_roots callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
-result = await client_session.call_tool(
-"test_list_roots", {"message": "test message"}
-)
+result = await client_session.call_tool("test_list_roots", {"message": "test message"})
assert result.isError is True
assert isinstance(result.content[0], TextContent)
-assert (
-result.content[0].text
-== "Error executing tool test_list_roots: List roots not supported"
-)
+assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"
@@ -49,9 +49,7 @@ async def test_logging_callback():

# Create a message handler to catch exceptions
async def message_handler(
-message: RequestResponder[types.ServerRequest, types.ClientResult]
-| types.ServerNotification
-| Exception,
+message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -21,9 +21,7 @@ async def test_sampling_callback():

callback_return = CreateMessageResult(
role="assistant",
-content=TextContent(
-type="text", text="This is a response from the sampling callback"
-),
+content=TextContent(type="text", text="This is a response from the sampling callback"),
model="test-model",
stopReason="endTurn",
)
@@ -37,24 +35,16 @@ async def test_sampling_callback():
@server.tool("test_sampling")
async def test_sampling_tool(message: str):
value = await server.get_context().session.create_message(
-messages=[
-SamplingMessage(
-role="user", content=TextContent(type="text", text=message)
-)
-],
+messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
max_tokens=100,
)
assert value == callback_return
return True

# Test with sampling callback
-async with create_session(
-server._mcp_server, sampling_callback=sampling_callback
-) as client_session:
+async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
# Make a request to trigger sampling callback
-result = await client_session.call_tool(
-"test_sampling", {"message": "Test message for sampling"}
-)
+result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"
@@ -62,12 +52,7 @@ async def test_sampling_callback():
# Test without sampling callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
-result = await client_session.call_tool(
-"test_sampling", {"message": "Test message for sampling"}
-)
+result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"})
assert result.isError is True
assert isinstance(result.content[0], TextContent)
-assert (
-result.content[0].text
-== "Error executing tool test_sampling: Sampling not supported"
-)
+assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
@@ -28,12 +28,8 @@ from mcp.types import (
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_session_initialize():
|
async def test_client_session_initialize():
|
||||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||||
SessionMessage
|
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
|
||||||
](1)
|
|
||||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
|
||||||
SessionMessage
|
|
||||||
](1)
|
|
||||||
|
|
||||||
initialized_notification = None
|
initialized_notification = None
|
||||||
|
|
||||||
@@ -70,9 +66,7 @@ async def test_client_session_initialize():
|
|||||||
JSONRPCResponse(
|
JSONRPCResponse(
|
||||||
jsonrpc="2.0",
|
jsonrpc="2.0",
|
||||||
id=jsonrpc_request.root.id,
|
id=jsonrpc_request.root.id,
|
||||||
result=result.model_dump(
|
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||||
by_alias=True, mode="json", exclude_none=True
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -81,16 +75,12 @@ async def test_client_session_initialize():
|
|||||||
jsonrpc_notification = session_notification.message
|
jsonrpc_notification = session_notification.message
|
||||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||||
initialized_notification = ClientNotification.model_validate(
|
initialized_notification = ClientNotification.model_validate(
|
||||||
jsonrpc_notification.model_dump(
|
jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||||
by_alias=True, mode="json", exclude_none=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a message handler to catch exceptions
|
# Create a message handler to catch exceptions
|
||||||
async def message_handler(
|
async def message_handler(
|
||||||
message: RequestResponder[types.ServerRequest, types.ClientResult]
|
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||||
| types.ServerNotification
|
|
||||||
| Exception,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(message, Exception):
|
if isinstance(message, Exception):
|
||||||
raise message
|
raise message
|
||||||
@@ -124,12 +114,8 @@ async def test_client_session_initialize():

 @pytest.mark.anyio
 async def test_client_session_custom_client_info():
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 custom_client_info = Implementation(name="test-client", version="1.2.3")
 received_client_info = None
@@ -161,9 +147,7 @@ async def test_client_session_custom_client_info():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -192,12 +176,8 @@ async def test_client_session_custom_client_info():

 @pytest.mark.anyio
 async def test_client_session_default_client_info():
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 received_client_info = None

@@ -228,9 +208,7 @@ async def test_client_session_default_client_info():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -259,12 +237,8 @@ async def test_client_session_default_client_info():
 @pytest.mark.anyio
 async def test_client_session_version_negotiation_success():
 """Test successful version negotiation with supported version"""
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 async def mock_server():
 session_message = await client_to_server_receive.receive()
@@ -294,9 +268,7 @@ async def test_client_session_version_negotiation_success():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -327,12 +299,8 @@ async def test_client_session_version_negotiation_success():
 @pytest.mark.anyio
 async def test_client_session_version_negotiation_failure():
 """Test version negotiation failure with unsupported version"""
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 async def mock_server():
 session_message = await client_to_server_receive.receive()
@@ -359,9 +327,7 @@ async def test_client_session_version_negotiation_failure():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -388,12 +354,8 @@ async def test_client_session_version_negotiation_failure():
 @pytest.mark.anyio
 async def test_client_capabilities_default():
 """Test that client capabilities are properly set with default callbacks"""
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 received_capabilities = None

@@ -424,9 +386,7 @@ async def test_client_capabilities_default():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -457,12 +417,8 @@ async def test_client_capabilities_default():
 @pytest.mark.anyio
 async def test_client_capabilities_with_custom_callbacks():
 """Test that client capabilities are properly set with custom callbacks"""
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](1)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)

 received_capabilities = None

@@ -508,9 +464,7 @@ async def test_client_capabilities_with_custom_callbacks():
 JSONRPCResponse(
 jsonrpc="2.0",
 id=jsonrpc_request.root.id,
-result=result.model_dump(
-by_alias=True, mode="json", exclude_none=True
-),
+result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
 )
 )
 )
@@ -536,14 +490,8 @@ async def test_client_capabilities_with_custom_callbacks():

 # Assert that capabilities are properly set with custom callbacks
 assert received_capabilities is not None
-assert (
-received_capabilities.sampling is not None
-) # Custom sampling callback provided
+assert received_capabilities.sampling is not None # Custom sampling callback provided
 assert isinstance(received_capabilities.sampling, types.SamplingCapability)
-assert (
-received_capabilities.roots is not None
-) # Custom list_roots callback provided
+assert received_capabilities.roots is not None # Custom list_roots callback provided
 assert isinstance(received_capabilities.roots, types.RootsCapability)
-assert (
-received_capabilities.roots.listChanged is True
-) # Should be True for custom callback
+assert received_capabilities.roots.listChanged is True # Should be True for custom callback
@@ -58,14 +58,10 @@ class TestClientSessionGroup:
 return f"{(server_info.name)}-{name}"

 mcp_session_group = ClientSessionGroup(component_name_hook=hook)
-mcp_session_group._tools = {
-"server1-my_tool": types.Tool(name="my_tool", inputSchema={})
-}
+mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
 mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
 text_content = types.TextContent(type="text", text="OK")
-mock_session.call_tool.return_value = types.CallToolResult(
-content=[text_content]
-)
+mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])

 # --- Test Execution ---
 result = await mcp_session_group.call_tool(
@@ -96,16 +92,12 @@ class TestClientSessionGroup:
 mock_prompt1 = mock.Mock(spec=types.Prompt)
 mock_prompt1.name = "prompt_c"
 mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
-mock_session.list_resources.return_value = mock.AsyncMock(
-resources=[mock_resource1]
-)
+mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
 mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])

 # --- Test Execution ---
 group = ClientSessionGroup(exit_stack=mock_exit_stack)
-with mock.patch.object(
-group, "_establish_session", return_value=(mock_server_info, mock_session)
-):
+with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
 await group.connect_to_server(StdioServerParameters(command="test"))

 # --- Assertions ---
@@ -141,12 +133,8 @@ class TestClientSessionGroup:
 return f"{server_info.name}.{name}"

 # --- Test Execution ---
-group = ClientSessionGroup(
-exit_stack=mock_exit_stack, component_name_hook=name_hook
-)
-with mock.patch.object(
-group, "_establish_session", return_value=(mock_server_info, mock_session)
-):
+group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
+with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
 await group.connect_to_server(StdioServerParameters(command="test"))

 # --- Assertions ---
@@ -231,9 +219,7 @@ class TestClientSessionGroup:
 # Need a dummy session associated with the existing tool
 mock_session = mock.MagicMock(spec=mcp.ClientSession)
 group._tool_to_session[existing_tool_name] = mock_session
-group._session_exit_stacks[mock_session] = mock.Mock(
-spec=contextlib.AsyncExitStack
-)
+group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)

 # --- Mock New Connection Attempt ---
 mock_server_info_new = mock.Mock(spec=types.Implementation)
@@ -243,9 +229,7 @@ class TestClientSessionGroup:
 # Configure the new session to return a tool with the *same name*
 duplicate_tool = mock.Mock(spec=types.Tool)
 duplicate_tool.name = existing_tool_name
-mock_session_new.list_tools.return_value = mock.AsyncMock(
-tools=[duplicate_tool]
-)
+mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
 # Keep other lists empty for simplicity
 mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
 mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])
@@ -266,9 +250,7 @@ class TestClientSessionGroup:

 # Verify the duplicate tool was *not* added again (state should be unchanged)
 assert len(group._tools) == 1 # Should still only have the original
-assert (
-group._tools[existing_tool_name] is not duplicate_tool
-) # Ensure it's the original mock
+assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock

 # No patching needed here
 async def test_disconnect_non_existent_server(self):
@@ -292,9 +274,7 @@ class TestClientSessionGroup:
 "mcp.client.session_group.sse_client",
 ), # url, headers, timeout, sse_read_timeout
 (
-StreamableHttpParameters(
-url="http://test.com/stream", terminate_on_close=False
-),
+StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
 "streamablehttp",
 "mcp.client.session_group.streamablehttp_client",
 ), # url, headers, timeout, sse_read_timeout, terminate_on_close
@@ -306,13 +286,9 @@ class TestClientSessionGroup:
 client_type_name, # Just for clarity or conditional logic if needed
 patch_target_for_client_func,
 ):
-with mock.patch(
-"mcp.client.session_group.mcp.ClientSession"
-) as mock_ClientSession_class:
+with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
 with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
-mock_client_cm_instance = mock.AsyncMock(
-name=f"{client_type_name}ClientCM"
-)
+mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
 mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
 mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")

@@ -344,9 +320,7 @@ class TestClientSessionGroup:

 # Mock session.initialize()
 mock_initialize_result = mock.AsyncMock(name="InitializeResult")
-mock_initialize_result.serverInfo = types.Implementation(
-name="foo", version="1"
-)
+mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
 mock_entered_session.initialize.return_value = mock_initialize_result

 # --- Test Execution ---
@@ -364,9 +338,7 @@ class TestClientSessionGroup:
 # --- Assertions ---
 # 1. Assert the correct specific client function was called
 if client_type_name == "stdio":
-mock_specific_client_func.assert_called_once_with(
-server_params_instance
-)
+mock_specific_client_func.assert_called_once_with(server_params_instance)
 elif client_type_name == "sse":
 mock_specific_client_func.assert_called_once_with(
 url=server_params_instance.url,
@@ -386,9 +358,7 @@ class TestClientSessionGroup:
 mock_client_cm_instance.__aenter__.assert_awaited_once()

 # 2. Assert ClientSession was called correctly
-mock_ClientSession_class.assert_called_once_with(
-mock_read_stream, mock_write_stream
-)
+mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
 mock_raw_session_cm.__aenter__.assert_awaited_once()
 mock_entered_session.initialize.assert_awaited_once()

@@ -50,20 +50,14 @@ async def test_stdio_client():
 break

 assert len(read_messages) == 2
-assert read_messages[0] == JSONRPCMessage(
-root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
-)
-assert read_messages[1] == JSONRPCMessage(
-root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
-)
+assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
+assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))


 @pytest.mark.anyio
 async def test_stdio_client_bad_path():
 """Check that the connection doesn't hang if process errors."""
-server_params = StdioServerParameters(
-command="python", args=["-c", "non-existent-file.py"]
-)
+server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"])
 async with stdio_client(server_params) as (read_stream, write_stream):
 async with ClientSession(read_stream, write_stream) as session:
 # The session should raise an error when the connection closes
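The one-line assertions above compare against pydantic wrapper models from mcp.types. A small sketch reusing the exact constructor calls seen in the diff; the model_dump_json call at the end is only an illustrative way to inspect the wire form:

from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

ping_request = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
pong_response = JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))
print(ping_request.model_dump_json(by_alias=True, exclude_none=True))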
@@ -17,9 +17,7 @@ async def test_list_tools_returns_all_tools():
 f"""Tool number {i}"""
 return i

-globals()[f"dummy_tool_{i}"] = (
-dummy_tool_func # Keep reference to avoid garbage collection
-)
+globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection

 # Get all tools
 tools = await mcp.list_tools()
@@ -30,6 +28,4 @@ async def test_list_tools_returns_all_tools():
 # Verify each tool is unique and has the correct name
 tool_names = [tool.name for tool in tools]
 expected_names = [f"tool_{i}" for i in range(num_tools)]
-assert sorted(tool_names) == sorted(
-expected_names
-), "Tool names don't match expected names"
+assert sorted(tool_names) == sorted(expected_names), "Tool names don't match expected names"
@@ -24,9 +24,7 @@ async def test_resource_templates():
 # Note: list_resource_templates() returns a decorator that wraps the handler
 # The handler returns a ServerResult with a ListResourceTemplatesResult inside
 result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest](
-types.ListResourceTemplatesRequest(
-method="resources/templates/list", params=None
-)
+types.ListResourceTemplatesRequest(method="resources/templates/list", params=None)
 )
 assert isinstance(result.root, types.ListResourceTemplatesResult)
 templates = result.root.resourceTemplates
@@ -61,9 +61,7 @@ async def test_resource_template_edge_cases():
 await mcp.read_resource("resource://users/123/posts") # Missing post_id

 with pytest.raises(ValueError, match="Unknown resource"):
-await mcp.read_resource(
-"resource://users/123/posts/456/extra"
-) # Extra path component
+await mcp.read_resource("resource://users/123/posts/456/extra") # Extra path component


 @pytest.mark.anyio
@@ -110,11 +108,7 @@ async def test_resource_template_client_interaction():

 # Verify invalid resource URIs raise appropriate errors
 with pytest.raises(Exception): # Specific exception type may vary
-await session.read_resource(
-AnyUrl("resource://users/123/posts")
-) # Missing post_id
+await session.read_resource(AnyUrl("resource://users/123/posts")) # Missing post_id

 with pytest.raises(Exception): # Specific exception type may vary
-await session.read_resource(
-AnyUrl("resource://users/123/invalid")
-) # Invalid template
+await session.read_resource(AnyUrl("resource://users/123/invalid")) # Invalid template
@@ -45,31 +45,19 @@ async def test_fastmcp_resource_mime_type():
 bytes_resource = mapping["test://image_bytes"]

 # Verify mime types
-assert (
-string_resource.mimeType == "image/png"
-), "String resource mime type not respected"
-assert (
-bytes_resource.mimeType == "image/png"
-), "Bytes resource mime type not respected"
+assert string_resource.mimeType == "image/png", "String resource mime type not respected"
+assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected"

 # Also verify the content can be read correctly
 string_result = await client.read_resource(AnyUrl("test://image"))
 assert len(string_result.contents) == 1
-assert (
-getattr(string_result.contents[0], "text") == base64_string
-), "Base64 string mismatch"
-assert (
-string_result.contents[0].mimeType == "image/png"
-), "String content mime type not preserved"
+assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch"
+assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved"

 bytes_result = await client.read_resource(AnyUrl("test://image_bytes"))
 assert len(bytes_result.contents) == 1
-assert (
-base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes
-), "Bytes mismatch"
-assert (
-bytes_result.contents[0].mimeType == "image/png"
-), "Bytes content mime type not preserved"
+assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch"
+assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved"


 async def test_lowlevel_resource_mime_type():
@@ -82,9 +70,7 @@ async def test_lowlevel_resource_mime_type():

 # Create test resources with specific mime types
 test_resources = [
-types.Resource(
-uri=AnyUrl("test://image"), name="test image", mimeType="image/png"
-),
+types.Resource(uri=AnyUrl("test://image"), name="test image", mimeType="image/png"),
 types.Resource(
 uri=AnyUrl("test://image_bytes"),
 name="test image bytes",
@@ -101,9 +87,7 @@ async def test_lowlevel_resource_mime_type():
 if str(uri) == "test://image":
 return [ReadResourceContents(content=base64_string, mime_type="image/png")]
 elif str(uri) == "test://image_bytes":
-return [
-ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")
-]
+return [ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")]
 raise Exception(f"Resource not found: {uri}")

 # Test that resources are listed with correct mime type
@@ -119,28 +103,16 @@ async def test_lowlevel_resource_mime_type():
 bytes_resource = mapping["test://image_bytes"]

 # Verify mime types
-assert (
-string_resource.mimeType == "image/png"
-), "String resource mime type not respected"
-assert (
-bytes_resource.mimeType == "image/png"
-), "Bytes resource mime type not respected"
+assert string_resource.mimeType == "image/png", "String resource mime type not respected"
+assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected"

 # Also verify the content can be read correctly
 string_result = await client.read_resource(AnyUrl("test://image"))
 assert len(string_result.contents) == 1
-assert (
-getattr(string_result.contents[0], "text") == base64_string
-), "Base64 string mismatch"
-assert (
-string_result.contents[0].mimeType == "image/png"
-), "String content mime type not preserved"
+assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch"
+assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved"

 bytes_result = await client.read_resource(AnyUrl("test://image_bytes"))
 assert len(bytes_result.contents) == 1
-assert (
-base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes
-), "Bytes mismatch"
-assert (
-bytes_result.contents[0].mimeType == "image/png"
-), "Bytes content mime type not preserved"
+assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch"
+assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved"
@@ -35,15 +35,7 @@ async def test_progress_token_zero_first_call():
 await ctx.report_progress(10, 10) # Complete

 # Verify progress notifications
-assert (
-mock_session.send_progress_notification.call_count == 3
-), "All progress notifications should be sent"
-mock_session.send_progress_notification.assert_any_call(
-progress_token=0, progress=0.0, total=10.0, message=None
-)
-mock_session.send_progress_notification.assert_any_call(
-progress_token=0, progress=5.0, total=10.0, message=None
-)
-mock_session.send_progress_notification.assert_any_call(
-progress_token=0, progress=10.0, total=10.0, message=None
-)
+assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent"
+mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None)
+mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None)
+mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None)
@@ -66,9 +66,7 @@ async def test_request_id_match() -> None:
 )

 await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req)))
-response = (
-await server_reader.receive()
-) # Get init response but don't need to check it
+response = await server_reader.receive() # Get init response but don't need to check it

 # Send initialized notification
 initialized_notification = JSONRPCNotification(
@@ -76,14 +74,10 @@ async def test_request_id_match() -> None:
 params=NotificationParams().model_dump(by_alias=True, exclude_none=True),
 jsonrpc="2.0",
 )
-await client_writer.send(
-SessionMessage(JSONRPCMessage(root=initialized_notification))
-)
+await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification)))

 # Send ping request with custom ID
-ping_request = JSONRPCRequest(
-id=custom_request_id, method="ping", params={}, jsonrpc="2.0"
-)
+ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0")

 await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request)))

@@ -91,9 +85,7 @@ async def test_request_id_match() -> None:
 response = await server_reader.receive()

 # Verify response ID matches request ID
-assert (
-response.message.root.id == custom_request_id
-), "Response ID should match request ID"
+assert response.message.root.id == custom_request_id, "Response ID should match request ID"

 # Cancel server task
 tg.cancel_scope.cancel()
@@ -47,11 +47,7 @@ async def test_server_base64_encoding_issue():
 # Register a resource handler that returns our test data
 @server.read_resource()
 async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]:
-return [
-ReadResourceContents(
-content=binary_data, mime_type="application/octet-stream"
-)
-]
+return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")]

 # Get the handler directly from the server
 handler = server.request_handlers[ReadResourceRequest]
@@ -11,12 +11,7 @@ from anyio.abc import TaskStatus
 from mcp.client.session import ClientSession
 from mcp.server.lowlevel import Server
 from mcp.shared.exceptions import McpError
-from mcp.types import (
-AudioContent,
-EmbeddedResource,
-ImageContent,
-TextContent,
-)
+from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent


 @pytest.mark.anyio
@@ -36,9 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
 slow_request_complete = anyio.Event()

 @server.call_tool()
-async def slow_tool(
-name: str, arg
-) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
+async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
 nonlocal request_count
 request_count += 1

@@ -75,9 +68,7 @@ async def test_notification_validation_error(tmp_path: Path):
 # - Long enough for fast operations (>10ms)
 # - Short enough for slow operations (<200ms)
 # - Not too short to avoid flakiness
-async with ClientSession(
-read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
-) as session:
+async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session:
 await session.initialize()

 # First call should work (fast operation)
@@ -23,12 +23,8 @@ async def test_malformed_initialize_request_does_not_crash_server():
 instead of crashing the server (HackerOne #3156202).
 """
 # Create in-memory streams for testing
-read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
-SessionMessage | Exception
-](10)
-write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
-SessionMessage
-](10)
+read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
+write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10)

 try:
 # Create a malformed initialize request (missing required params field)
@@ -78,9 +74,7 @@ async def test_malformed_initialize_request_does_not_crash_server():
 method="tools/call",
 # params=None # Missing required params
 )
-another_request_message = SessionMessage(
-message=JSONRPCMessage(another_malformed_request)
-)
+another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request))

 await read_send_stream.send(another_request_message)
 await anyio.sleep(0.1)
@@ -109,12 +103,8 @@ async def test_multiple_concurrent_malformed_requests():
 Test that multiple concurrent malformed requests don't crash the server.
 """
 # Create in-memory streams for testing
-read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
-SessionMessage | Exception
-](100)
-write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
-SessionMessage
-](100)
+read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100)
+write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](100)

 try:
 # Start a server session
@@ -136,9 +126,7 @@ async def test_multiple_concurrent_malformed_requests():
 method="initialize",
 # params=None # Missing required params
 )
-request_message = SessionMessage(
-message=JSONRPCMessage(malformed_request)
-)
+request_message = SessionMessage(message=JSONRPCMessage(malformed_request))
 malformed_requests.append(request_message)

 # Send all requests
@@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken:
 class TestBearerAuthBackend:
 """Tests for the BearerAuthBackend class."""

-async def test_no_auth_header(
-self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
-):
+async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
 """Test authentication with no Authorization header."""
 backend = BearerAuthBackend(provider=mock_oauth_provider)
 request = Request({"type": "http", "headers": []})
 result = await backend.authenticate(request)
 assert result is None

-async def test_non_bearer_auth_header(
-self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
-):
+async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
 """Test authentication with non-Bearer Authorization header."""
 backend = BearerAuthBackend(provider=mock_oauth_provider)
 request = Request(
@@ -139,9 +135,7 @@ class TestBearerAuthBackend:
 result = await backend.authenticate(request)
 assert result is None

-async def test_invalid_token(
-self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
-):
+async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
 """Test authentication with invalid token."""
 backend = BearerAuthBackend(provider=mock_oauth_provider)
 request = Request(
@@ -160,9 +154,7 @@ class TestBearerAuthBackend:
 ):
 """Test authentication with expired token."""
 backend = BearerAuthBackend(provider=mock_oauth_provider)
-add_token_to_provider(
-mock_oauth_provider, "expired_token", expired_access_token
-)
+add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token)
 request = Request(
 {
 "type": "http",
@@ -203,9 +195,7 @@ class TestBearerAuthBackend:
 ):
 """Test authentication with token that has no expiry."""
 backend = BearerAuthBackend(provider=mock_oauth_provider)
-add_token_to_provider(
-mock_oauth_provider, "no_expiry_token", no_expiry_access_token
-)
+add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token)
 request = Request(
 {
 "type": "http",
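The backend tests above build bare Starlette Request objects straight from raw ASGI scopes. A hedged sketch of the same pattern with an Authorization header added; the header value is an illustrative assumption, not taken from this diff:

from starlette.requests import Request

# Scope headers are raw (name, value) byte pairs, matching the empty list used in the tests above.
request = Request({"type": "http", "headers": [(b"authorization", b"Bearer some_token")]})
assert request.headers.get("Authorization") == "Bearer some_token"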
@@ -128,16 +128,12 @@ class TestRegistrationErrorHandling:

 class TestAuthorizeErrorHandling:
 @pytest.mark.anyio
-async def test_authorize_error_handling(
-self, client, oauth_provider, registered_client, pkce_challenge
-):
+async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge):
 # Mock the authorize method to raise an authorize error
 with unittest.mock.patch.object(
 oauth_provider,
 "authorize",
-side_effect=AuthorizeError(
-error="access_denied", error_description="The user denied the request"
-),
+side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"),
 ):
 # Register the client
 client_id = registered_client["client_id"]
@@ -169,9 +165,7 @@ class TestAuthorizeErrorHandling:

 class TestTokenErrorHandling:
 @pytest.mark.anyio
-async def test_token_error_handling_auth_code(
-self, client, oauth_provider, registered_client, pkce_challenge
-):
+async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge):
 # Register the client and get an auth code
 client_id = registered_client["client_id"]
 client_secret = registered_client["client_secret"]
@@ -224,9 +218,7 @@ class TestTokenErrorHandling:
 assert data["error_description"] == "The authorization code is invalid"

 @pytest.mark.anyio
-async def test_token_error_handling_refresh_token(
-self, client, oauth_provider, registered_client, pkce_challenge
-):
+async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge):
 # Register the client and get tokens
 client_id = registered_client["client_id"]
 client_secret = registered_client["client_secret"]
@@ -47,9 +47,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
 async def register_client(self, client_info: OAuthClientInformationFull):
 self.clients[client_info.client_id] = client_info

-async def authorize(
-self, client: OAuthClientInformationFull, params: AuthorizationParams
-) -> str:
+async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
 # toy authorize implementation which just immediately generates an authorization
 # code and completes the redirect
 code = AuthorizationCode(
@@ -63,9 +61,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
 )
 self.auth_codes[code.code] = code

-return construct_redirect_uri(
-str(params.redirect_uri), code=code.code, state=params.state
-)
+return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state)

 async def load_authorization_code(
 self, client: OAuthClientInformationFull, authorization_code: str
@@ -102,9 +98,7 @@ class MockOAuthProvider(OAuthAuthorizationServerProvider):
 refresh_token=refresh_token,
 )

-async def load_refresh_token(
-self, client: OAuthClientInformationFull, refresh_token: str
-) -> RefreshToken | None:
+async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
 old_access_token = self.refresh_tokens.get(refresh_token)
 if old_access_token is None:
 return None
@@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider):

 @pytest.fixture
 async def test_client(auth_app):
-async with httpx.AsyncClient(
-transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com"
-) as client:
+async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client:
 yield client


@@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request):
 def pkce_challenge():
 """Create a PKCE challenge with code_verifier and code_challenge."""
 code_verifier = "some_random_verifier_string"
-code_challenge = (
-base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
-.decode()
-.rstrip("=")
-)
+code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=")

 return {"code_verifier": code_verifier, "code_challenge": code_challenge}

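The joined one-liner in the pkce_challenge fixture above is the standard S256 transform: base64url-encode the SHA-256 digest of the verifier and strip the trailing padding. The same computation, spelled out step by step:

import base64
import hashlib

code_verifier = "some_random_verifier_string"
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
print(code_challenge)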
@@ -356,17 +344,13 @@ class TestAuthEndpoints:

 metadata = response.json()
 assert metadata["issuer"] == "https://auth.example.com/"
-assert (
-metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
-)
+assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize"
 assert metadata["token_endpoint"] == "https://auth.example.com/token"
 assert metadata["registration_endpoint"] == "https://auth.example.com/register"
 assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
 assert metadata["response_types_supported"] == ["code"]
 assert metadata["code_challenge_methods_supported"] == ["S256"]
-assert metadata["token_endpoint_auth_methods_supported"] == [
-"client_secret_post"
-]
+assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
 assert metadata["grant_types_supported"] == [
 "authorization_code",
 "refresh_token",
@@ -386,14 +370,10 @@ class TestAuthEndpoints:
 )
 error_response = response.json()
 assert error_response["error"] == "invalid_request"
-assert (
-"error_description" in error_response
-) # Contains validation error messages
+assert "error_description" in error_response # Contains validation error messages

 @pytest.mark.anyio
-async def test_token_invalid_auth_code(
-self, test_client, registered_client, pkce_challenge
-):
+async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge):
 """Test token endpoint error - authorization code does not exist."""
 # Try to use a non-existent authorization code
 response = await test_client.post(
@@ -413,9 +393,7 @@ class TestAuthEndpoints:
 assert response.status_code == 400
 error_response = response.json()
 assert error_response["error"] == "invalid_grant"
-assert (
-"authorization code does not exist" in error_response["error_description"]
-)
+assert "authorization code does not exist" in error_response["error_description"]

 @pytest.mark.anyio
 async def test_token_expired_auth_code(
@@ -458,9 +436,7 @@ class TestAuthEndpoints:
 assert response.status_code == 400
 error_response = response.json()
 assert error_response["error"] == "invalid_grant"
-assert (
-"authorization code has expired" in error_response["error_description"]
-)
+assert "authorization code has expired" in error_response["error_description"]

 @pytest.mark.anyio
 @pytest.mark.parametrize(
@@ -475,9 +451,7 @@ class TestAuthEndpoints:
 ],
 indirect=True,
 )
-async def test_token_redirect_uri_mismatch(
-self, test_client, registered_client, auth_code, pkce_challenge
-):
+async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge):
 """Test token endpoint error - redirect URI mismatch."""
 # Try to use the code with a different redirect URI
 response = await test_client.post(
@@ -498,9 +472,7 @@ class TestAuthEndpoints:
 assert "redirect_uri did not match" in error_response["error_description"]

 @pytest.mark.anyio
-async def test_token_code_verifier_mismatch(
-self, test_client, registered_client, auth_code
-):
+async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code):
 """Test token endpoint error - PKCE code verifier mismatch."""
 # Try to use the code with an incorrect code verifier
 response = await test_client.post(
@@ -569,9 +541,7 @@ class TestAuthEndpoints:

 # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default)
 # Mock the time.time() function to return a value 4 hours in the future
-with unittest.mock.patch(
-"time.time", return_value=current_time + 14400
-): # 4 hours = 14400 seconds
+with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds
 # Try to use the refresh token which should now be considered expired
 response = await test_client.post(
 "/token",
@@ -590,9 +560,7 @@ class TestAuthEndpoints:
 assert "refresh token has expired" in error_response["error_description"]

 @pytest.mark.anyio
-async def test_token_invalid_scope(
-self, test_client, registered_client, auth_code, pkce_challenge
-):
+async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge):
 """Test token endpoint error - invalid scope in refresh token request."""
 # Exchange authorization code for tokens
 token_response = await test_client.post(
@@ -628,9 +596,7 @@ class TestAuthEndpoints:
|
|||||||
assert "cannot request scope" in error_response["error_description"]
|
assert "cannot request scope" in error_response["error_description"]
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_client_registration(
|
async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider):
|
||||||
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider
|
|
||||||
):
|
|
||||||
"""Test client registration."""
|
"""Test client registration."""
|
||||||
client_metadata = {
|
client_metadata = {
|
||||||
"redirect_uris": ["https://client.example.com/callback"],
|
"redirect_uris": ["https://client.example.com/callback"],
|
||||||
@@ -656,9 +622,7 @@ class TestAuthEndpoints:
# ) is not None

@pytest.mark.anyio
async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient):
"""Test client registration with missing required fields."""
# Missing redirect_uris which is a required field
client_metadata = {
@@ -677,9 +641,7 @@ class TestAuthEndpoints:
assert error_data["error_description"] == "redirect_uris: Field required"

@pytest.mark.anyio
async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient):
"""Test client registration with invalid URIs."""
# Invalid redirect_uri format
client_metadata = {
@@ -696,14 +658,11 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert error_data["error_description"] == (
"redirect_uris.0: Input should be a valid URL, " "relative URL without a base"
)

@pytest.mark.anyio
async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient):
"""Test client registration with empty redirect_uris array."""
client_metadata = {
"redirect_uris": [], # Empty array
@@ -719,8 +678,7 @@ class TestAuthEndpoints:
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert (
error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0"
)

@pytest.mark.anyio
@@ -875,12 +833,7 @@ class TestAuthEndpoints:
assert response.status_code == 200

# Verify that the token was revoked
assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None

@pytest.mark.anyio
async def test_revoke_invalid_token(self, test_client, registered_client):
@@ -913,9 +866,7 @@ class TestAuthEndpoints:
assert "token_type_hint" in error_response["error_description"]

@pytest.mark.anyio
async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient):
"""Test client registration with scopes that are not allowed."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
@@ -955,18 +906,14 @@ class TestAuthEndpoints:
assert client_info["scope"] == "read write"

# Retrieve the client from the store to verify default scopes
registered_client = await mock_oauth_provider.get_client(client_info["client_id"])
assert registered_client is not None

# Check that default scopes were applied
assert registered_client.scope == "read write"

@pytest.mark.anyio
async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient):
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Test Client",
@@ -981,19 +928,14 @@ class TestAuthEndpoints:
error_data = response.json()
assert "error" in error_data
assert error_data["error"] == "invalid_client_metadata"
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"


class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""

@pytest.mark.anyio
async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with missing client_id.

According to the OAuth2.0 spec, if client_id is missing, the server should
@@ -1017,9 +959,7 @@ class TestAuthorizeEndpointErrors:
assert "client_id" in response.text.lower()

@pytest.mark.anyio
async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge):
"""Test authorization endpoint with invalid client_id.

According to the OAuth2.0 spec, if client_id is invalid, the server should
@@ -1202,9 +1142,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"

@pytest.mark.anyio
async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client):
"""Test authorization endpoint with missing PKCE code_challenge.

Missing PKCE parameters should result in invalid_request error.
@@ -1233,9 +1171,7 @@ class TestAuthorizeEndpointErrors:
assert query_params["state"][0] == "test_state"

@pytest.mark.anyio
async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge):
"""Test authorization endpoint with invalid scope.

Invalid scope should redirect with invalid_scope error.
@@ -18,9 +18,7 @@ class TestRenderPrompt:
return "Hello, world!"

prompt = Prompt.from_function(fn)
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]

@pytest.mark.anyio
async def test_async_fn(self):
@@ -28,9 +26,7 @@ class TestRenderPrompt:
return "Hello, world!"

prompt = Prompt.from_function(fn)
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]

@pytest.mark.anyio
async def test_fn_with_args(self):
@@ -39,11 +35,7 @@ class TestRenderPrompt:

prompt = Prompt.from_function(fn)
assert await prompt.render(arguments={"name": "World"}) == [
UserMessage(content=TextContent(type="text", text="Hello, World! You're 30 years old."))
]

@pytest.mark.anyio
@@ -61,21 +53,15 @@ class TestRenderPrompt:
return UserMessage(content="Hello, world!")

prompt = Prompt.from_function(fn)
assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]

@pytest.mark.anyio
async def test_fn_returns_assistant_message(self):
async def fn() -> AssistantMessage:
return AssistantMessage(content=TextContent(type="text", text="Hello, world!"))

prompt = Prompt.from_function(fn)
assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))]

@pytest.mark.anyio
async def test_fn_returns_multiple_messages(self):
@@ -156,9 +142,7 @@ class TestRenderPrompt:

prompt = Prompt.from_function(fn)
assert await prompt.render() == [
UserMessage(content=TextContent(type="text", text="Please analyze this file:")),
UserMessage(
content=EmbeddedResource(
type="resource",
@@ -169,9 +153,7 @@ class TestRenderPrompt:
),
)
),
AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")),
]

@pytest.mark.anyio
@@ -72,9 +72,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
messages = await manager.render_prompt("fn")
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))]

@pytest.mark.anyio
async def test_render_prompt_with_args(self):
@@ -87,9 +85,7 @@ class TestPromptManager:
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
messages = await manager.render_prompt("fn", arguments={"name": "World"})
assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))]

@pytest.mark.anyio
async def test_render_unknown_prompt(self):
@@ -100,9 +100,7 @@ class TestFileResource:
with pytest.raises(ValueError, match="Error reading file"):
await resource.read()

@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows")
@pytest.mark.anyio
async def test_permission_error(self, temp_file: Path):
"""Test reading a file without permissions."""
@@ -28,9 +28,7 @@ def complex_arguments_fn(
# list[str] | str is an interesting case because if it comes in as JSON like
# "[\"a\", \"b\"]" then it will be naively parsed as a string.
list_str_or_str: list[str] | str,
an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")],
an_int_annotated_with_field_and_others: Annotated[
int,
str, # Should be ignored, really
@@ -42,9 +40,7 @@ def complex_arguments_fn(
"123",
456,
],
field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)],
unannotated,
my_model_a: SomeInputModelA,
my_model_a_forward_ref: "SomeInputModelA",
@@ -179,9 +175,7 @@ def test_str_vs_list_str():
def test_skip_names():
"""Test that skipped parameters are not included in the model"""

def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool):
return keep_this, skip_this, also_keep, also_skip

# Skip some parameters
@@ -130,11 +130,7 @@ def make_everything_fastmcp() -> FastMCP:

# Request sampling from the client
result = await ctx.session.create_message(
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))],
max_tokens=100,
temperature=0.7,
)
@@ -278,11 +274,7 @@ def make_fastmcp_stateless_http_app():
def run_server(server_port: int) -> None:
"""Run the server."""
_, app = make_fastmcp_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting server on port {server_port}")
server.run()

@@ -290,11 +282,7 @@ def run_server(server_port: int) -> None:
def run_everything_legacy_sse_http_server(server_port: int) -> None:
"""Run the comprehensive server with all features."""
_, app = make_everything_fastmcp_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting comprehensive server on port {server_port}")
server.run()

@@ -302,11 +290,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None:
def run_streamable_http_server(server_port: int) -> None:
"""Run the StreamableHTTP server."""
_, app = make_fastmcp_streamable_http_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting StreamableHTTP server on port {server_port}")
server.run()

@@ -314,11 +298,7 @@ def run_streamable_http_server(server_port: int) -> None:
def run_everything_server(server_port: int) -> None:
"""Run the comprehensive StreamableHTTP server with all features."""
_, app = make_everything_fastmcp_streamable_http_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting comprehensive StreamableHTTP server on port {server_port}")
server.run()

@@ -326,11 +306,7 @@ def run_everything_server(server_port: int) -> None:
def run_stateless_http_server(server_port: int) -> None:
"""Run the stateless StreamableHTTP server."""
_, app = make_fastmcp_stateless_http_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"Starting stateless StreamableHTTP server on port {server_port}")
server.run()

@@ -369,9 +345,7 @@ def server(server_port: int) -> Generator[None, None, None]:
@pytest.fixture()
def streamable_http_server(http_server_port: int) -> Generator[None, None, None]:
"""Start the StreamableHTTP server in a separate process."""
proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True)
print("Starting StreamableHTTP server process")
proc.start()

@@ -388,9 +362,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None]
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts")

yield

@@ -427,9 +399,7 @@ def stateless_http_server(
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts")

yield

@@ -459,9 +429,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:


@pytest.mark.anyio
async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None:
"""Test that FastMCP works with StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(http_server_url + "/mcp") as (
@@ -484,9 +452,7 @@ async def test_fastmcp_streamable_http(


@pytest.mark.anyio
async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None:
"""Test that FastMCP works with stateless StreamableHTTP transport."""
# Connect to the server using StreamableHTTP
async with streamablehttp_client(stateless_http_server_url + "/mcp") as (
@@ -562,9 +528,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts")

yield

@@ -601,10 +565,7 @@ def everything_streamable_http_server(
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts")

yield

@@ -648,9 +609,7 @@ class NotificationCollector:
await self.handle_tool_list_changed(message.root.params)


async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None:
"""
Test all MCP features using the provided session.

@@ -680,9 +639,7 @@ async def call_all_mcp_features(
# Test progress callback functionality
progress_updates = []

async def progress_callback(progress: float, total: float | None, message: str | None) -> None:
"""Collect progress updates for testing (async version)."""
progress_updates.append((progress, total, message))
print(f"Progress: {progress}/{total} - {message}")
@@ -726,19 +683,12 @@ async def call_all_mcp_features(

# Verify we received log messages from the sampling tool
assert len(collector.log_messages) > 0
assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages)
assert any("Received sampling result from model" in msg.data for msg in collector.log_messages)

# 4. Test notification tool
notification_message = "test_notifications"
notification_result = await session.call_tool("notification_tool", {"message": notification_message})
assert len(notification_result.content) == 1
assert isinstance(notification_result.content[0], TextContent)
assert "Sent notifications and logs" in notification_result.content[0].text
@@ -773,36 +723,24 @@ async def call_all_mcp_features(

# 2. Dynamic resource
resource_category = "test"
dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}"))
assert isinstance(dynamic_content, ReadResourceResult)
assert len(dynamic_content.contents) == 1
assert isinstance(dynamic_content.contents[0], TextResourceContents)
assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text

# 3. Template resource
resource_id = "456"
template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data"))
assert isinstance(template_content, ReadResourceResult)
assert len(template_content.contents) == 1
assert isinstance(template_content.contents[0], TextResourceContents)
assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text

# Test prompts
# 1. Simple prompt
prompts = await session.list_prompts()
simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None)
assert simple_prompt is not None

prompt_topic = "AI"
@@ -812,16 +750,12 @@ async def call_all_mcp_features(
# The actual message structure depends on the prompt implementation

# 2. Complex prompt
complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None)
assert complex_prompt is not None

query = "What is AI?"
context = "technical"
complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context})
assert isinstance(complex_result, GetPromptResult)
assert len(complex_result.messages) >= 1

@@ -837,9 +771,7 @@ async def call_all_mcp_features(
print(f"Received headers: {headers_data}")

# Test 6: Call tool that returns full context
context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"})
assert len(context_result.content) == 1
assert isinstance(context_result.content[0], TextContent)

@@ -871,9 +803,7 @@ async def sampling_callback(


@pytest.mark.anyio
async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None:
"""Test all MCP features work correctly with SSE transport."""

# Create notification collector
@@ -59,9 +59,7 @@ class TestServer:
"""Test SSE app creation with different mount paths."""
# Test with default mount path
mcp = FastMCP()
with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize:
mcp.sse_app()
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/", "/messages/")
@@ -69,18 +67,14 @@ class TestServer:
# Test with custom mount path in settings
mcp = FastMCP()
mcp.settings.mount_path = "/custom"
with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize:
mcp.sse_app()
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/custom", "/messages/")

# Test with mount_path parameter
mcp = FastMCP()
with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize:
mcp.sse_app(mount_path="/param")
# Verify _normalize_path was called with correct args
mock_normalize.assert_called_once_with("/param", "/messages/")
@@ -103,9 +97,7 @@ class TestServer:

# Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"

# Test with mount path as parameter
mcp = FastMCP()
@@ -121,20 +113,14 @@ class TestServer:

# Verify path values
assert sse_routes[0].path == "/sse", "SSE route path should be /sse"
assert mount_routes[0].path == "/messages", "Mount route path should be /messages"

@pytest.mark.anyio
async def test_non_ascii_description(self):
"""Test that FastMCP handles non-ASCII characters in descriptions correctly"""
mcp = FastMCP()

@mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉"))
def hello_world(name: str = "世界") -> str:
return f"¡Hola, {name}! 👋"

@@ -187,9 +173,7 @@ class TestServer:
async def test_add_resource_decorator_incorrect_usage(self):
mcp = FastMCP()

with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"):

@mcp.resource # Missing parentheses #type: ignore
def get_data(x: str) -> str:
@@ -373,9 +357,7 @@ class TestServerResources:
def get_text():
return "Hello, world!"

resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text)
mcp.add_resource(resource)

async with client_session(mcp._mcp_server) as client:
@@ -411,9 +393,7 @@ class TestServerResources:
text_file = tmp_path / "test.txt"
text_file.write_text("Hello from file!")

resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file)
mcp.add_resource(resource)

async with client_session(mcp._mcp_server) as client:
@@ -440,10 +420,7 @@ class TestServerResources:
async with client_session(mcp._mcp_server) as client:
result = await client.read_resource(AnyUrl("file://test.bin"))
assert isinstance(result.contents[0], BlobResourceContents)
assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode()

@pytest.mark.anyio
async def test_function_resource(self):
@@ -532,9 +509,7 @@ class TestServerResourceTemplates:
return f"Data for {org}/{repo}"

async with client_session(mcp._mcp_server) as client:
result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data"))
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Data for cursor/fastmcp"

@@ -147,9 +147,7 @@ class TestAddTools:

def test_add_lambda_with_no_name(self):
manager = ToolManager()
with pytest.raises(ValueError, match="You must provide a name for lambda functions"):
manager.add_tool(lambda x: x)

def test_warn_on_duplicate_tools(self, caplog):
@@ -346,9 +344,7 @@ class TestContextHandling:
tool = manager.add_tool(tool_without_context)
assert tool.context_kwarg is None

def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str:
return str(x)

tool = manager.add_tool(tool_with_parametrized_context)
@@ -10,13 +10,7 @@ from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations


@pytest.mark.anyio
@@ -45,18 +39,12 @@ async def test_lowlevel_server_tool_annotations():
)
]

server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

# Message handler for client
async def message_handler(
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path):

@server.read_resource()
async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]:
return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")]

# Get the handler directly from the server
handler = server.request_handlers[types.ReadResourceRequest]
@@ -20,18 +20,12 @@ from mcp.types import (

@pytest.mark.anyio
async def test_server_session_initialize():
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
@@ -54,9 +48,7 @@ async def test_server_session_initialize():
if isinstance(message, Exception):
raise message

if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification):
received_initialized = True
return

@@ -111,12 +103,8 @@ async def test_server_capabilities():
@pytest.mark.anyio
async def test_server_session_initialize_with_older_protocol_version():
"""Test that server accepts and responds with older protocol (2024-11-05)."""
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)

received_initialized = False
received_protocol_version = None
@@ -137,9 +125,7 @@ async def test_server_session_initialize_with_older_protocol_version():
if isinstance(message, Exception):
raise message

if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification):
received_initialized = True
return

@@ -157,9 +143,7 @@ async def test_server_session_initialize_with_older_protocol_version():
params=types.InitializeRequestParams(
protocolVersion="2024-11-05",
capabilities=types.ClientCapabilities(),
clientInfo=types.Implementation(name="test-client", version="1.0.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
)
@@ -22,9 +22,10 @@ async def test_stdio_server():
stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
stdin.seek(0)

async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
read_stream,
write_stream,
):
received_messages = []
async with read_stream:
async for message in read_stream:
@@ -36,12 +37,8 @@ async def test_stdio_server():

# Verify received messages
assert len(received_messages) == 2
assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))
assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}))

# Test sending responses from the server
responses = [
@@ -58,13 +55,7 @@ async def test_stdio_server():
output_lines = stdout.readlines()
assert len(output_lines) == 2

received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines]
assert len(received_responses) == 2
assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"))
assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}))
@@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once():
async with manager.run():
pass

assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value)


@pytest.mark.anyio
@@ -51,10 +48,7 @@ async def test_run_prevents_concurrent_calls():

# One should succeed, one should fail
assert len(errors) == 1
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])


@pytest.mark.anyio
@@ -76,6 +70,4 @@ async def test_handle_request_without_run_raises_error():
with pytest.raises(RuntimeError) as excinfo:
await manager.handle_request(scope, receive, send)

assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)
@@ -22,12 +22,8 @@ from mcp.shared.session import (
 async def test_bidirectional_progress_notifications():
 """Test that both client and server can send progress notifications."""
 # Create memory streams for client/server
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](5)
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](5)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)

 # Run a server session so we can send progress updates in tool
 async def run_server():
@@ -134,9 +130,7 @@ async def test_bidirectional_progress_notifications():

 # Client message handler to store progress notifications
 async def handle_client_message(
-message: RequestResponder[types.ServerRequest, types.ClientResult]
-| types.ServerNotification
-| Exception,
+message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
 if isinstance(message, Exception):
 raise message
@@ -172,9 +166,7 @@ async def test_bidirectional_progress_notifications():
 await client_session.list_tools()

 # Call test_tool with progress token
-await client_session.call_tool(
-"test_tool", {"_meta": {"progressToken": client_progress_token}}
-)
+await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}})

 # Send progress notifications from client to server
 await client_session.send_progress_notification(
@@ -221,12 +213,8 @@ async def test_bidirectional_progress_notifications():
 async def test_progress_context_manager():
 """Test client using progress context manager for sending progress notifications."""
 # Create memory streams for client/server
-server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
-SessionMessage
-](5)
-client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
-SessionMessage
-](5)
+server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5)
+client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5)

 # Track progress updates
 server_progress_updates = []
@@ -270,9 +258,7 @@ async def test_progress_context_manager():

 # Client message handler
 async def handle_client_message(
-message: RequestResponder[types.ServerRequest, types.ClientResult]
-| types.ServerNotification
-| Exception,
+message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
 if isinstance(message, Exception):
 raise message
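
These hunks collapse the typed anyio memory-stream constructors onto one line. A minimal sketch of that constructor in isolation (assumes anyio >= 4.0, where create_memory_object_stream is subscriptable; str stands in for the SessionMessage payload type used in the tests):

import anyio


async def main() -> None:
    # Buffered, typed in-memory stream pair, like the client/server wiring in these tests.
    send_stream, receive_stream = anyio.create_memory_object_stream[str](5)
    async with send_stream, receive_stream:
        await send_stream.send("ping")
        print(await receive_stream.receive())


anyio.run(main)
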
@@ -90,9 +90,7 @@ async def test_request_cancellation():
 ClientRequest(
 types.CallToolRequest(
 method="tools/call",
-params=types.CallToolRequestParams(
-name="slow_tool", arguments={}
-),
+params=types.CallToolRequestParams(name="slow_tool", arguments={}),
 )
 ),
 types.CallToolResult,
@@ -103,9 +101,7 @@ async def test_request_cancellation():
 assert "Request cancelled" in str(e)
 ev_cancelled.set()

-async with create_connected_server_and_client_session(
-make_server()
-) as client_session:
+async with create_connected_server_and_client_session(make_server()) as client_session:
 async with anyio.create_task_group() as tg:
 tg.start_soon(make_request, client_session)

@@ -60,11 +60,7 @@ class ServerTest(Server):
 await anyio.sleep(2.0)
 return f"Slow response from {uri.host}"

-raise McpError(
-error=ErrorData(
-code=404, message="OOPS! no resource with that URI was found"
-)
-)
+raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))

 @self.list_tools()
 async def handle_list_tools() -> list[Tool]:
@@ -88,12 +84,8 @@ def make_server_app() -> Starlette:
 server = ServerTest()

 async def handle_sse(request: Request) -> Response:
-async with sse.connect_sse(
-request.scope, request.receive, request._send
-) as streams:
-await server.run(
-streams[0], streams[1], server.create_initialization_options()
-)
+async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
+await server.run(streams[0], streams[1], server.create_initialization_options())
 return Response()

 app = Starlette(
@@ -108,11 +100,7 @@ def make_server_app() -> Starlette:

 def run_server(server_port: int) -> None:
 app = make_server_app()
-server = uvicorn.Server(
-config=uvicorn.Config(
-app=app, host="127.0.0.1", port=server_port, log_level="error"
-)
-)
+server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
 print(f"starting server on {server_port}")
 server.run()

@@ -124,9 +112,7 @@ def run_server(server_port: int) -> None:

 @pytest.fixture()
 def server(server_port: int) -> Generator[None, None, None]:
-proc = multiprocessing.Process(
-target=run_server, kwargs={"server_port": server_port}, daemon=True
-)
+proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
 print("starting process")
 proc.start()

@@ -171,10 +157,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
 async def connection_test() -> None:
 async with http_client.stream("GET", "/sse") as response:
 assert response.status_code == 200
-assert (
-response.headers["content-type"]
-== "text/event-stream; charset=utf-8"
-)
+assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

 line_number = 0
 async for line in response.aiter_lines():
@@ -206,9 +189,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non


 @pytest.fixture
-async def initialized_sse_client_session(
-server, server_url: str
-) -> AsyncGenerator[ClientSession, None]:
+async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
 async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
 async with ClientSession(*streams) as session:
 await session.initialize()
@@ -236,9 +217,7 @@ async def test_sse_client_exception_handling(


 @pytest.mark.anyio
-@pytest.mark.skip(
-"this test highlights a possible bug in SSE read timeout exception handling"
-)
+@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling")
 async def test_sse_client_timeout(
 initialized_sse_client_session: ClientSession,
 ) -> None:
@@ -260,11 +239,7 @@ async def test_sse_client_timeout(
 def run_mounted_server(server_port: int) -> None:
 app = make_server_app()
 main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
-server = uvicorn.Server(
-config=uvicorn.Config(
-app=main_app, host="127.0.0.1", port=server_port, log_level="error"
-)
-)
+server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error"))
 print(f"starting server on {server_port}")
 server.run()

@@ -276,9 +251,7 @@ def run_mounted_server(server_port: int) -> None:

 @pytest.fixture()
 def mounted_server(server_port: int) -> Generator[None, None, None]:
-proc = multiprocessing.Process(
-target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
-)
+proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True)
 print("starting process")
 proc.start()

@@ -308,9 +281,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]:


 @pytest.mark.anyio
-async def test_sse_client_basic_connection_mounted_app(
-mounted_server: None, server_url: str
-) -> None:
+async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None:
 async with sse_client(server_url + "/mounted_app/sse") as streams:
 async with ClientSession(*streams) as session:
 # Test initialization
@@ -372,12 +343,8 @@ def run_context_server(server_port: int) -> None:
 context_server = RequestContextServer()

 async def handle_sse(request: Request) -> Response:
-async with sse.connect_sse(
-request.scope, request.receive, request._send
-) as streams:
-await context_server.run(
-streams[0], streams[1], context_server.create_initialization_options()
-)
+async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
+await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
 return Response()

 app = Starlette(
@@ -387,11 +354,7 @@ def run_context_server(server_port: int) -> None:
 ]
 )

-server = uvicorn.Server(
-config=uvicorn.Config(
-app=app, host="127.0.0.1", port=server_port, log_level="error"
-)
-)
+server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
 print(f"starting context server on {server_port}")
 server.run()

@@ -399,9 +362,7 @@ def run_context_server(server_port: int) -> None:
 @pytest.fixture()
 def context_server(server_port: int) -> Generator[None, None, None]:
 """Fixture that provides a server with request context capture"""
-proc = multiprocessing.Process(
-target=run_context_server, kwargs={"server_port": server_port}, daemon=True
-)
+proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True)
 print("starting context server process")
 proc.start()

@@ -418,9 +379,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:
 time.sleep(0.1)
 attempt += 1
 else:
-raise RuntimeError(
-f"Context server failed to start after {max_attempts} attempts"
-)
+raise RuntimeError(f"Context server failed to start after {max_attempts} attempts")

 yield

@@ -432,9 +391,7 @@ def context_server(server_port: int) -> Generator[None, None, None]:


 @pytest.mark.anyio
-async def test_request_context_propagation(
-context_server: None, server_url: str
-) -> None:
+async def test_request_context_propagation(context_server: None, server_url: str) -> None:
 """Test that request context is properly propagated through SSE transport."""
 # Test with custom headers
 custom_headers = {
@@ -458,11 +415,7 @@ async def test_request_context_propagation(
 # Parse the JSON response

 assert len(tool_result.content) == 1
-headers_data = json.loads(
-tool_result.content[0].text
-if tool_result.content[0].type == "text"
-else "{}"
-)
+headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")

 # Verify headers were propagated
 assert headers_data.get("authorization") == "Bearer test-token"
@@ -487,15 +440,11 @@ async def test_request_context_isolation(context_server: None, server_url: str)
 await session.initialize()

 # Call the tool that echoes context
-tool_result = await session.call_tool(
-"echo_context", {"request_id": f"request-{i}"}
-)
+tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})

 assert len(tool_result.content) == 1
 context_data = json.loads(
-tool_result.content[0].text
-if tool_result.content[0].type == "text"
-else "{}"
+tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
 )
 contexts.append(context_data)

@@ -514,8 +463,4 @@ def test_sse_message_id_coercion():
 """
 json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
 msg = types.JSONRPCMessage.model_validate_json(json_message)
-assert msg == snapshot(
-types.JSONRPCMessage(
-root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)
-)
-)
+assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)))
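
Several hunks above reflow the same fixture shape: boot the uvicorn test server in a daemon subprocess, poll until it is reachable, then tear it down. A compressed sketch of that pattern (run_server and server_port refer to the helpers shown in the hunks; wait_for_port is a hypothetical helper added here for brevity):

import multiprocessing
import socket
import time
from collections.abc import Generator

import pytest


def wait_for_port(port: int, attempts: int = 20) -> None:
    # Hypothetical helper: poll until the subprocess accepts TCP connections.
    for _ in range(attempts):
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=0.1):
                return
        except OSError:
            time.sleep(0.1)
    raise RuntimeError(f"server failed to start after {attempts} attempts")


@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
    # Same shape as the fixtures in the diff: daemon process, readiness poll, teardown.
    # run_server is the module-level helper shown in the hunks above.
    proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
    proc.start()
    wait_for_port(server_port)
    yield
    proc.kill()
    proc.join(timeout=2)
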
@@ -72,9 +72,7 @@ class SimpleEventStore(EventStore):
 self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = []
 self._event_id_counter = 0

-async def store_event(
-self, stream_id: StreamId, message: types.JSONRPCMessage
-) -> EventId:
+async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId:
 """Store an event and return its ID."""
 self._event_id_counter += 1
 event_id = str(self._event_id_counter)
@@ -156,9 +154,7 @@ class ServerTest(Server):

 # When the tool is called, send a notification to test GET stream
 if name == "test_tool_with_standalone_notification":
-await ctx.session.send_resource_updated(
-uri=AnyUrl("http://test_resource")
-)
+await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource"))
 return [TextContent(type="text", text=f"Called {name}")]

 elif name == "long_running_with_checkpoints":
@@ -189,9 +185,7 @@ class ServerTest(Server):
 messages=[
 types.SamplingMessage(
 role="user",
-content=types.TextContent(
-type="text", text="Server needs client sampling"
-),
+content=types.TextContent(type="text", text="Server needs client sampling"),
 )
 ],
 max_tokens=100,
@@ -199,11 +193,7 @@ class ServerTest(Server):
 )

 # Return the sampling result in the tool response
-response = (
-sampling_result.content.text
-if sampling_result.content.type == "text"
-else None
-)
+response = sampling_result.content.text if sampling_result.content.type == "text" else None
 return [
 TextContent(
 type="text",
@@ -214,9 +204,7 @@ class ServerTest(Server):
 return [TextContent(type="text", text=f"Called {name}")]


-def create_app(
-is_json_response_enabled=False, event_store: EventStore | None = None
-) -> Starlette:
+def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette:
 """Create a Starlette application for testing using the session manager.

 Args:
@@ -245,9 +233,7 @@ def create_app(
 return app


-def run_server(
-port: int, is_json_response_enabled=False, event_store: EventStore | None = None
-) -> None:
+def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None:
 """Run the test server.

 Args:
@@ -300,9 +286,7 @@ def json_server_port() -> int:
 @pytest.fixture
 def basic_server(basic_server_port: int) -> Generator[None, None, None]:
 """Start a basic server."""
-proc = multiprocessing.Process(
-target=run_server, kwargs={"port": basic_server_port}, daemon=True
-)
+proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
 proc.start()

 # Wait for server to be running
@@ -778,9 +762,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
 @pytest.mark.anyio
 async def test_streamablehttp_client_resource_read(initialized_client_session):
 """Test client resource read functionality."""
-response = await initialized_client_session.read_resource(
-uri=AnyUrl("foobar://test-resource")
-)
+response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource"))
 assert len(response.contents) == 1
 assert response.contents[0].uri == AnyUrl("foobar://test-resource")
 assert response.contents[0].text == "Read test-resource"
@@ -805,17 +787,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
 async def test_streamablehttp_client_error_handling(initialized_client_session):
 """Test error handling in client."""
 with pytest.raises(McpError) as exc_info:
-await initialized_client_session.read_resource(
-uri=AnyUrl("unknown://test-error")
-)
+await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
 assert exc_info.value.error.code == 0
 assert "Unknown resource: unknown://test-error" in exc_info.value.error.message


 @pytest.mark.anyio
-async def test_streamablehttp_client_session_persistence(
-basic_server, basic_server_url
-):
+async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url):
 """Test that session ID persists across requests."""
 async with streamablehttp_client(f"{basic_server_url}/mcp") as (
 read_stream,
@@ -843,9 +821,7 @@ async def test_streamablehttp_client_session_persistence(


 @pytest.mark.anyio
-async def test_streamablehttp_client_json_response(
-json_response_server, json_server_url
-):
+async def test_streamablehttp_client_json_response(json_response_server, json_server_url):
 """Test client with JSON response mode."""
 async with streamablehttp_client(f"{json_server_url}/mcp") as (
 read_stream,
@@ -882,9 +858,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):

 # Define message handler to capture notifications
 async def message_handler(
-message: RequestResponder[types.ServerRequest, types.ClientResult]
-| types.ServerNotification
-| Exception,
+message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
 if isinstance(message, types.ServerNotification):
 notifications_received.append(message)
@@ -894,9 +868,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
 write_stream,
 _,
 ):
-async with ClientSession(
-read_stream, write_stream, message_handler=message_handler
-) as session:
+async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
 # Initialize the session - this triggers the GET stream setup
 result = await session.initialize()
 assert isinstance(result, InitializeResult)
@@ -914,15 +886,11 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url):
 assert str(notif.root.params.uri) == "http://test_resource/"
 resource_update_found = True

-assert (
-resource_update_found
-), "ResourceUpdatedNotification not received via GET stream"
+assert resource_update_found, "ResourceUpdatedNotification not received via GET stream"


 @pytest.mark.anyio
-async def test_streamablehttp_client_session_termination(
-basic_server, basic_server_url
-):
+async def test_streamablehttp_client_session_termination(basic_server, basic_server_url):
 """Test client session termination functionality."""

 captured_session_id = None
@@ -963,9 +931,7 @@ async def test_streamablehttp_client_session_termination(


 @pytest.mark.anyio
-async def test_streamablehttp_client_session_termination_204(
-basic_server, basic_server_url, monkeypatch
-):
+async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch):
 """Test client session termination functionality with a 204 response.

 This test patches the httpx client to return a 204 response for DELETEs.
@@ -1040,9 +1006,7 @@ async def test_streamablehttp_client_resumption(event_server):
 tool_started = False

 async def message_handler(
-message: RequestResponder[types.ServerRequest, types.ClientResult]
-| types.ServerNotification
-| Exception,
+message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
 ) -> None:
 if isinstance(message, types.ServerNotification):
 captured_notifications.append(message)
@@ -1062,9 +1026,7 @@ async def test_streamablehttp_client_resumption(event_server):
 write_stream,
 get_session_id,
 ):
-async with ClientSession(
-read_stream, write_stream, message_handler=message_handler
-) as session:
+async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
 # Initialize the session
 result = await session.initialize()
 assert isinstance(result, InitializeResult)
@@ -1082,9 +1044,7 @@ async def test_streamablehttp_client_resumption(event_server):
 types.ClientRequest(
 types.CallToolRequest(
 method="tools/call",
-params=types.CallToolRequestParams(
-name="long_running_with_checkpoints", arguments={}
-),
+params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
 )
 ),
 types.CallToolResult,
@@ -1114,9 +1074,7 @@ async def test_streamablehttp_client_resumption(event_server):
 write_stream,
 _,
 ):
-async with ClientSession(
-read_stream, write_stream, message_handler=message_handler
-) as session:
+async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
 # Don't initialize - just use the existing session

 # Resume the tool with the resumption token
@@ -1129,9 +1087,7 @@ async def test_streamablehttp_client_resumption(event_server):
 types.ClientRequest(
 types.CallToolRequest(
 method="tools/call",
-params=types.CallToolRequestParams(
-name="long_running_with_checkpoints", arguments={}
-),
+params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
 )
 ),
 types.CallToolResult,
@@ -1149,14 +1105,11 @@ async def test_streamablehttp_client_resumption(event_server):
 # Should not have the first notification
 # Check that "Tool started" notification isn't repeated when resuming
 assert not any(
-isinstance(n.root, types.LoggingMessageNotification)
-and n.root.params.data == "Tool started"
+isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
 for n in captured_notifications
 )
 # there is no intersection between pre and post notifications
-assert not any(
-n in captured_notifications_pre for n in captured_notifications
-)
+assert not any(n in captured_notifications_pre for n in captured_notifications)


 @pytest.mark.anyio
@@ -1175,11 +1128,7 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
 nonlocal sampling_callback_invoked, captured_message_params
 sampling_callback_invoked = True
 captured_message_params = params
-message_received = (
-params.messages[0].content.text
-if params.messages[0].content.type == "text"
-else None
-)
+message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None

 return types.CreateMessageResult(
 role="assistant",
@@ -1212,19 +1161,13 @@ async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
 # Verify the tool result contains the expected content
 assert len(tool_result.content) == 1
 assert tool_result.content[0].type == "text"
-assert (
-"Response from sampling: Received message from server"
-in tool_result.content[0].text
-)
+assert "Response from sampling: Received message from server" in tool_result.content[0].text

 # Verify sampling callback was invoked
 assert sampling_callback_invoked
 assert captured_message_params is not None
 assert len(captured_message_params.messages) == 1
-assert (
-captured_message_params.messages[0].content.text
-== "Server needs client sampling"
-)
+assert captured_message_params.messages[0].content.text == "Server needs client sampling"


 # Context-aware server implementation for testing request context propagation
@@ -1325,9 +1268,7 @@ def run_context_aware_server(port: int):
 @pytest.fixture
 def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
 """Start the context-aware server in a separate process."""
-proc = multiprocessing.Process(
-target=run_context_aware_server, args=(basic_server_port,), daemon=True
-)
+proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
 proc.start()

 # Wait for server to be running
@@ -1342,9 +1283,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
 time.sleep(0.1)
 attempt += 1
 else:
-raise RuntimeError(
-f"Context-aware server failed to start after {max_attempts} attempts"
-)
+raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts")

 yield

@@ -1355,9 +1294,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:


 @pytest.mark.anyio
-async def test_streamablehttp_request_context_propagation(
-context_aware_server: None, basic_server_url: str
-) -> None:
+async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
 """Test that request context is properly propagated through StreamableHTTP."""
 custom_headers = {
 "Authorization": "Bearer test-token",
@@ -1365,9 +1302,11 @@ async def test_streamablehttp_request_context_propagation(
 "X-Trace-Id": "trace-123",
 }

-async with streamablehttp_client(
-f"{basic_server_url}/mcp", headers=custom_headers
-) as (read_stream, write_stream, _):
+async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as (
+read_stream,
+write_stream,
+_,
+):
 async with ClientSession(read_stream, write_stream) as session:
 result = await session.initialize()
 assert isinstance(result, InitializeResult)
@@ -1388,9 +1327,7 @@ async def test_streamablehttp_request_context_propagation(


 @pytest.mark.anyio
-async def test_streamablehttp_request_context_isolation(
-context_aware_server: None, basic_server_url: str
-) -> None:
+async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
 """Test that request contexts are isolated between StreamableHTTP clients."""
 contexts = []

@@ -1402,16 +1339,12 @@ async def test_streamablehttp_request_context_isolation(
 "Authorization": f"Bearer token-{i}",
 }

-async with streamablehttp_client(
-f"{basic_server_url}/mcp", headers=headers
-) as (read_stream, write_stream, _):
+async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
 async with ClientSession(read_stream, write_stream) as session:
 await session.initialize()

 # Call the tool that echoes context
-tool_result = await session.call_tool(
-"echo_context", {"request_id": f"request-{i}"}
-)
+tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})

 assert len(tool_result.content) == 1
 assert isinstance(tool_result.content[0], TextContent)
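
The StreamableHTTP hunks above repeatedly reflow the same client bootstrap. A standalone sketch of that bootstrap (the URL and header value are placeholders; assumes an MCP server is already listening there):

import anyio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main() -> None:
    headers = {"Authorization": "Bearer test-token"}
    # streamablehttp_client yields (read_stream, write_stream, get_session_id).
    async with streamablehttp_client("http://127.0.0.1:8000/mcp", headers=headers) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            result = await session.initialize()
            print(result.serverInfo.name)


anyio.run(main)
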
@@ -54,11 +54,7 @@ class ServerTest(Server):
 await anyio.sleep(2.0)
 return f"Slow response from {uri.host}"

-raise McpError(
-error=ErrorData(
-code=404, message="OOPS! no resource with that URI was found"
-)
-)
+raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found"))

 @self.list_tools()
 async def handle_list_tools() -> list[Tool]:
@@ -81,12 +77,8 @@ def make_server_app() -> Starlette:
 server = ServerTest()

 async def handle_ws(websocket):
-async with websocket_server(
-websocket.scope, websocket.receive, websocket.send
-) as streams:
-await server.run(
-streams[0], streams[1], server.create_initialization_options()
-)
+async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
+await server.run(streams[0], streams[1], server.create_initialization_options())

 app = Starlette(
 routes=[
@@ -99,11 +91,7 @@ def make_server_app() -> Starlette:

 def run_server(server_port: int) -> None:
 app = make_server_app()
-server = uvicorn.Server(
-config=uvicorn.Config(
-app=app, host="127.0.0.1", port=server_port, log_level="error"
-)
-)
+server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
 print(f"starting server on {server_port}")
 server.run()

@@ -115,9 +103,7 @@ def run_server(server_port: int) -> None:

 @pytest.fixture()
 def server(server_port: int) -> Generator[None, None, None]:
-proc = multiprocessing.Process(
-target=run_server, kwargs={"server_port": server_port}, daemon=True
-)
+proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
 print("starting process")
 proc.start()

@@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]:


 @pytest.fixture()
-async def initialized_ws_client_session(
-server, server_url: str
-) -> AsyncGenerator[ClientSession, None]:
+async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]:
 """Create and initialize a WebSocket client session"""
 async with websocket_client(server_url + "/ws") as streams:
 async with ClientSession(*streams) as session:
@@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response(
 initialized_ws_client_session: ClientSession,
 ) -> None:
 """Test a successful request and response via WebSocket"""
-result = await initialized_ws_client_session.read_resource(
-AnyUrl("foobar://example")
-)
+result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
 assert isinstance(result, ReadResourceResult)
 assert isinstance(result.contents, list)
 assert len(result.contents) > 0
@@ -218,9 +200,7 @@ async def test_ws_client_timeout(

 # Now test that we can still use the session after a timeout
 with anyio.fail_after(5): # Longer timeout to allow completion
-result = await initialized_ws_client_session.read_resource(
-AnyUrl("foobar://example")
-)
+result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example"))
 assert isinstance(result, ReadResourceResult)
 assert isinstance(result.contents, list)
 assert len(result.contents) > 0
@@ -31,9 +31,7 @@ async def test_complex_inputs():

 async with client_session(mcp._mcp_server) as client:
 tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]}
-result = await client.call_tool(
-"name_shrimp", {"tank": tank, "extra_names": ["charlie"]}
-)
+result = await client.call_tool("name_shrimp", {"tank": tank, "extra_names": ["charlie"]})
 assert len(result.content) == 3
 assert isinstance(result.content[0], TextContent)
 assert isinstance(result.content[1], TextContent)
@@ -86,9 +84,7 @@ async def test_desktop(monkeypatch):
 def test_docs_examples(example: CodeExample, eval_example: EvalExample):
 ruff_ignore: list[str] = ["F841", "I001"]

-eval_example.set_config(
-ruff_ignore=ruff_ignore, target_version="py310", line_length=88
-)
+eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=88)

 if eval_example.update_examples: # pragma: no cover
 eval_example.format(example)