mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 06:24:26 +01:00
chore: create union for working with message content (#939)
This commit is contained in:
@@ -41,14 +41,7 @@ def main(
|
||||
app = Server("mcp-streamable-http-stateless-demo")
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||
ctx = app.request_context
|
||||
interval = arguments.get("interval", 1.0)
|
||||
count = arguments.get("count", 5)
|
||||
|
||||
@@ -45,14 +45,7 @@ def main(
|
||||
app = Server("mcp-streamable-http-demo")
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||
ctx = app.request_context
|
||||
interval = arguments.get("interval", 1.0)
|
||||
count = arguments.get("count", 5)
|
||||
|
||||
@@ -7,9 +7,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client
|
||||
|
||||
async def fetch_website(
|
||||
url: str,
|
||||
) -> list[
|
||||
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
|
||||
]:
|
||||
) -> list[types.Content]:
|
||||
headers = {
|
||||
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
||||
}
|
||||
@@ -31,14 +29,7 @@ def main(port: int, transport: str) -> int:
|
||||
app = Server("mcp-website-fetcher")
|
||||
|
||||
@app.call_tool()
|
||||
async def fetch_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
async def fetch_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||
if name != "fetch":
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
if "url" not in arguments:
|
||||
|
||||
@@ -7,18 +7,16 @@ from typing import Any, Literal
|
||||
import pydantic_core
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
||||
|
||||
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
from mcp.types import Content, TextContent
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base class for all prompt messages."""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: CONTENT_TYPES
|
||||
content: Content
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
if isinstance(content, str):
|
||||
content = TextContent(type="text", text=content)
|
||||
super().__init__(content=content, **kwargs)
|
||||
@@ -29,7 +27,7 @@ class UserMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "user"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
@@ -38,7 +36,7 @@ class AssistantMessage(Message):
|
||||
|
||||
role: Literal["user", "assistant"] = "assistant"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
||||
def __init__(self, content: str | Content, **kwargs: Any):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
AudioContent,
|
||||
EmbeddedResource,
|
||||
Content,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
ToolAnnotations,
|
||||
)
|
||||
@@ -256,9 +254,7 @@ class FastMCP:
|
||||
request_context = None
|
||||
return Context(request_context=request_context, fastmcp=self)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any]
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||
@@ -842,12 +838,12 @@ class FastMCP:
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
) -> Sequence[Content]:
|
||||
"""Convert a result to a sequence of content objects."""
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
|
||||
if isinstance(result, Content):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
|
||||
@@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
|
||||
],
|
||||
Awaitable[Iterable[types.Content]],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
@@ -667,11 +667,14 @@ class EmbeddedResource(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
Content = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
"""Describes a message returned as part of a prompt."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
content: Content
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -787,7 +790,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
||||
class CallToolResult(Result):
|
||||
"""The server's response to a tool call."""
|
||||
|
||||
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
|
||||
content: list[Content]
|
||||
isError: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from anyio.abc import TaskStatus
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
||||
from mcp.types import Content, TextContent
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
slow_request_complete = anyio.Event()
|
||||
|
||||
@server.call_tool()
|
||||
async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
async def slow_tool(name: str, arg) -> Sequence[Content]:
|
||||
nonlocal request_count
|
||||
request_count += 1
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import AnyUrl
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
|
||||
from mcp.server.fastmcp.prompts.base import Message, UserMessage
|
||||
from mcp.server.fastmcp.resources import FileResource, FunctionResource
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
from mcp.shared.exceptions import McpError
|
||||
@@ -18,6 +18,8 @@ from mcp.shared.memory import (
|
||||
from mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
Content,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
@@ -192,7 +194,7 @@ def image_tool_fn(path: str) -> Image:
|
||||
return Image(path)
|
||||
|
||||
|
||||
def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
|
||||
def mixed_content_tool_fn() -> list[Content]:
|
||||
return [
|
||||
TextContent(type="text", text="Hello"),
|
||||
ImageContent(type="image", data="abc", mimeType="image/png"),
|
||||
|
||||
Reference in New Issue
Block a user