diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 2b3c2aa..d76e195 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -41,14 +41,7 @@ def main( app = Server("mcp-streamable-http-stateless-demo") @app.call_tool() - async def call_tool( - name: str, arguments: dict - ) -> list[ - types.TextContent - | types.ImageContent - | types.AudioContent - | types.EmbeddedResource - ]: + async def call_tool(name: str, arguments: dict) -> list[types.Content]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 5115c12..c6b13ab 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -45,14 +45,7 @@ def main( app = Server("mcp-streamable-http-demo") @app.call_tool() - async def call_tool( - name: str, arguments: dict - ) -> list[ - types.TextContent - | types.ImageContent - | types.AudioContent - | types.EmbeddedResource - ]: + async def call_tool(name: str, arguments: dict) -> list[types.Content]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 46f9bbf..29c741c 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -7,9 +7,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client async def fetch_website( url: str, -) -> list[ - types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource -]: +) -> list[types.Content]: headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } @@ -31,14 +29,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-website-fetcher") @app.call_tool() - async def fetch_tool( - name: str, arguments: dict - ) -> list[ - types.TextContent - | types.ImageContent - | types.AudioContent - | types.EmbeddedResource - ]: + async def fetch_tool(name: str, arguments: dict) -> list[types.Content]: if name != "fetch": raise ValueError(f"Unknown tool: {name}") if "url" not in arguments: diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index b8a23e6..a269d6a 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -7,18 +7,16 @@ from typing import Any, Literal import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call -from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent - -CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource +from mcp.types import Content, TextContent class Message(BaseModel): """Base class for all prompt messages.""" role: Literal["user", "assistant"] - content: CONTENT_TYPES + content: Content - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | Content, **kwargs: Any): if isinstance(content, str): content = TextContent(type="text", text=content) super().__init__(content=content, **kwargs) @@ -29,7 +27,7 @@ class UserMessage(Message): role: Literal["user", "assistant"] = "user" - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | Content, **kwargs: Any): super().__init__(content=content, **kwargs) @@ -38,7 +36,7 @@ class AssistantMessage(Message): role: Literal["user", "assistant"] = "assistant" - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | Content, **kwargs: Any): super().__init__(content=content, **kwargs) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index ddcd4df..f745663 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import ( AnyFunction, - AudioContent, - EmbeddedResource, + Content, GetPromptResult, - ImageContent, TextContent, ToolAnnotations, ) @@ -256,9 +254,7 @@ class FastMCP: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool( - self, name: str, arguments: dict[str, Any] - ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]: """Call a tool by name with arguments.""" context = self.get_context() result = await self._tool_manager.call_tool(name, arguments, context=context) @@ -842,12 +838,12 @@ class FastMCP: def _convert_to_content( result: Any, -) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: +) -> Sequence[Content]: """Convert a result to a sequence of content objects.""" if result is None: return [] - if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource): + if isinstance(result, Content): return [result] if isinstance(result, Image): diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3615399..7a24781 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]): def decorator( func: Callable[ ..., - Awaitable[ - Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource] - ], + Awaitable[Iterable[types.Content]], ], ): logger.debug("Registering handler for CallToolRequest") diff --git a/src/mcp/types.py b/src/mcp/types.py index 5df82c3..2949ed8 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -667,11 +667,14 @@ class EmbeddedResource(BaseModel): model_config = ConfigDict(extra="allow") +Content = TextContent | ImageContent | AudioContent | EmbeddedResource + + class PromptMessage(BaseModel): """Describes a message returned as part of a prompt.""" role: Role - content: TextContent | ImageContent | AudioContent | EmbeddedResource + content: Content model_config = ConfigDict(extra="allow") @@ -787,7 +790,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): class CallToolResult(Result): """The server's response to a tool call.""" - content: list[TextContent | ImageContent | AudioContent | EmbeddedResource] + content: list[Content] isError: bool = False diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 6fbc700..3647871 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -11,7 +11,7 @@ from anyio.abc import TaskStatus from mcp.client.session import ClientSession from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError -from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent +from mcp.types import Content, TextContent @pytest.mark.anyio @@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path): slow_request_complete = anyio.Event() @server.call_tool() - async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: + async def slow_tool(name: str, arg) -> Sequence[Content]: nonlocal request_count request_count += 1 diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index b2d941d..e734512 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -8,7 +8,7 @@ from pydantic import AnyUrl from starlette.routing import Mount, Route from mcp.server.fastmcp import Context, FastMCP -from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage +from mcp.server.fastmcp.prompts.base import Message, UserMessage from mcp.server.fastmcp.resources import FileResource, FunctionResource from mcp.server.fastmcp.utilities.types import Image from mcp.shared.exceptions import McpError @@ -18,6 +18,8 @@ from mcp.shared.memory import ( from mcp.types import ( AudioContent, BlobResourceContents, + Content, + EmbeddedResource, ImageContent, TextContent, TextResourceContents, @@ -192,7 +194,7 @@ def image_tool_fn(path: str) -> Image: return Image(path) -def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]: +def mixed_content_tool_fn() -> list[Content]: return [ TextContent(type="text", text="Hello"), ImageContent(type="image", data="abc", mimeType="image/png"),