mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-23 00:34:27 +01:00
chore: create union for working with message content (#939)
This commit is contained in:
@@ -7,18 +7,16 @@ from typing import Any, Literal
|
||||
import pydantic_core
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
||||
|
||||
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
from mcp.types import Content, TextContent
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base class for all prompt messages."""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: CONTENT_TYPES
|
||||
content: Content
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
if isinstance(content, str):
|
||||
content = TextContent(type="text", text=content)
|
||||
super().__init__(content=content, **kwargs)
|
||||
@@ -29,7 +27,7 @@ class UserMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "user"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
@@ -38,7 +36,7 @@ class AssistantMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "assistant"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
AudioContent,
|
||||
EmbeddedResource,
|
||||
Content,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
ToolAnnotations,
|
||||
)
|
||||
@@ -256,9 +254,7 @@ class FastMCP:
|
||||
request_context = None
|
||||
return Context(request_context=request_context, fastmcp=self)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any]
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||
@@ -842,12 +838,12 @@ class FastMCP:
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
) -> Sequence[Content]:
|
||||
"""Convert a result to a sequence of content objects."""
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
|
||||
if isinstance(result, Content):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
|
||||
@@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
|
||||
],
|
||||
Awaitable[Iterable[types.Content]],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
Reference in New Issue
Block a user