chore: create union for working with message content (#939)

This commit is contained in:
Luca Chang
2025-06-12 00:01:33 -07:00
committed by GitHub
parent 185fa49fd1
commit d69b290b65
9 changed files with 25 additions and 51 deletions

View File

@@ -41,14 +41,7 @@ def main(
app = Server("mcp-streamable-http-stateless-demo")
@app.call_tool()
async def call_tool(
name: str, arguments: dict
) -> list[
types.TextContent
| types.ImageContent
| types.AudioContent
| types.EmbeddedResource
]:
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
ctx = app.request_context
interval = arguments.get("interval", 1.0)
count = arguments.get("count", 5)

View File

@@ -45,14 +45,7 @@ def main(
app = Server("mcp-streamable-http-demo")
@app.call_tool()
async def call_tool(
name: str, arguments: dict
) -> list[
types.TextContent
| types.ImageContent
| types.AudioContent
| types.EmbeddedResource
]:
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
ctx = app.request_context
interval = arguments.get("interval", 1.0)
count = arguments.get("count", 5)

View File

@@ -7,9 +7,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client
async def fetch_website(
url: str,
) -> list[
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
]:
) -> list[types.Content]:
headers = {
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
}
@@ -31,14 +29,7 @@ def main(port: int, transport: str) -> int:
app = Server("mcp-website-fetcher")
@app.call_tool()
async def fetch_tool(
name: str, arguments: dict
) -> list[
types.TextContent
| types.ImageContent
| types.AudioContent
| types.EmbeddedResource
]:
async def fetch_tool(name: str, arguments: dict) -> list[types.Content]:
if name != "fetch":
raise ValueError(f"Unknown tool: {name}")
if "url" not in arguments:

View File

@@ -7,18 +7,16 @@ from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
from mcp.types import Content, TextContent
class Message(BaseModel):
"""Base class for all prompt messages."""
role: Literal["user", "assistant"]
content: CONTENT_TYPES
content: Content
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -29,7 +27,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -38,7 +36,7 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
super().__init__(content=content, **kwargs)

View File

@@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
AudioContent,
EmbeddedResource,
Content,
GetPromptResult,
ImageContent,
TextContent,
ToolAnnotations,
)
@@ -256,9 +254,7 @@ class FastMCP:
request_context = None
return Context(request_context=request_context, fastmcp=self)
async def call_tool(
self, name: str, arguments: dict[str, Any]
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
"""Call a tool by name with arguments."""
context = self.get_context()
result = await self._tool_manager.call_tool(name, arguments, context=context)
@@ -842,12 +838,12 @@ class FastMCP:
def _convert_to_content(
result: Any,
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
) -> Sequence[Content]:
"""Convert a result to a sequence of content objects."""
if result is None:
return []
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
if isinstance(result, Content):
return [result]
if isinstance(result, Image):

View File

@@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]):
def decorator(
func: Callable[
...,
Awaitable[
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
],
Awaitable[Iterable[types.Content]],
],
):
logger.debug("Registering handler for CallToolRequest")

View File

@@ -667,11 +667,14 @@ class EmbeddedResource(BaseModel):
model_config = ConfigDict(extra="allow")
Content = TextContent | ImageContent | AudioContent | EmbeddedResource
class PromptMessage(BaseModel):
"""Describes a message returned as part of a prompt."""
role: Role
content: TextContent | ImageContent | AudioContent | EmbeddedResource
content: Content
model_config = ConfigDict(extra="allow")
@@ -787,7 +790,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
class CallToolResult(Result):
"""The server's response to a tool call."""
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
content: list[Content]
isError: bool = False

View File

@@ -11,7 +11,7 @@ from anyio.abc import TaskStatus
from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
from mcp.types import Content, TextContent
@pytest.mark.anyio
@@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
slow_request_complete = anyio.Event()
@server.call_tool()
async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
async def slow_tool(name: str, arg) -> Sequence[Content]:
nonlocal request_count
request_count += 1

View File

@@ -8,7 +8,7 @@ from pydantic import AnyUrl
from starlette.routing import Mount, Route
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
from mcp.server.fastmcp.prompts.base import Message, UserMessage
from mcp.server.fastmcp.resources import FileResource, FunctionResource
from mcp.server.fastmcp.utilities.types import Image
from mcp.shared.exceptions import McpError
@@ -18,6 +18,8 @@ from mcp.shared.memory import (
from mcp.types import (
AudioContent,
BlobResourceContents,
Content,
EmbeddedResource,
ImageContent,
TextContent,
TextResourceContents,
@@ -192,7 +194,7 @@ def image_tool_fn(path: str) -> Image:
return Image(path)
def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
def mixed_content_tool_fn() -> list[Content]:
return [
TextContent(type="text", text="Hello"),
ImageContent(type="image", data="abc", mimeType="image/png"),