chore: create union for working with message content (#939)

This commit is contained in:
Luca Chang
2025-06-12 00:01:33 -07:00
committed by GitHub
parent 185fa49fd1
commit d69b290b65
9 changed files with 25 additions and 51 deletions

View File

@@ -7,18 +7,16 @@ from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
from mcp.types import Content, TextContent
class Message(BaseModel):
"""Base class for all prompt messages."""
role: Literal["user", "assistant"]
content: CONTENT_TYPES
content: Content
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -29,7 +27,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -38,7 +36,7 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
def __init__(self, content: str | Content, **kwargs: Any):
super().__init__(content=content, **kwargs)

View File

@@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
from mcp.types import (
AnyFunction,
AudioContent,
EmbeddedResource,
Content,
GetPromptResult,
ImageContent,
TextContent,
ToolAnnotations,
)
@@ -256,9 +254,7 @@ class FastMCP:
request_context = None
return Context(request_context=request_context, fastmcp=self)
async def call_tool(
self, name: str, arguments: dict[str, Any]
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
"""Call a tool by name with arguments."""
context = self.get_context()
result = await self._tool_manager.call_tool(name, arguments, context=context)
@@ -842,12 +838,12 @@ class FastMCP:
def _convert_to_content(
result: Any,
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
) -> Sequence[Content]:
"""Convert a result to a sequence of content objects."""
if result is None:
return []
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
if isinstance(result, Content):
return [result]
if isinstance(result, Image):

View File

@@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]):
def decorator(
func: Callable[
...,
Awaitable[
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
],
Awaitable[Iterable[types.Content]],
],
):
logger.debug("Registering handler for CallToolRequest")