mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 14:34:27 +01:00
chore: create union for working with message content (#939)
This commit is contained in:
@@ -41,14 +41,7 @@ def main(
|
|||||||
app = Server("mcp-streamable-http-stateless-demo")
|
app = Server("mcp-streamable-http-stateless-demo")
|
||||||
|
|
||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def call_tool(
|
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||||
name: str, arguments: dict
|
|
||||||
) -> list[
|
|
||||||
types.TextContent
|
|
||||||
| types.ImageContent
|
|
||||||
| types.AudioContent
|
|
||||||
| types.EmbeddedResource
|
|
||||||
]:
|
|
||||||
ctx = app.request_context
|
ctx = app.request_context
|
||||||
interval = arguments.get("interval", 1.0)
|
interval = arguments.get("interval", 1.0)
|
||||||
count = arguments.get("count", 5)
|
count = arguments.get("count", 5)
|
||||||
|
|||||||
@@ -45,14 +45,7 @@ def main(
|
|||||||
app = Server("mcp-streamable-http-demo")
|
app = Server("mcp-streamable-http-demo")
|
||||||
|
|
||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def call_tool(
|
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||||
name: str, arguments: dict
|
|
||||||
) -> list[
|
|
||||||
types.TextContent
|
|
||||||
| types.ImageContent
|
|
||||||
| types.AudioContent
|
|
||||||
| types.EmbeddedResource
|
|
||||||
]:
|
|
||||||
ctx = app.request_context
|
ctx = app.request_context
|
||||||
interval = arguments.get("interval", 1.0)
|
interval = arguments.get("interval", 1.0)
|
||||||
count = arguments.get("count", 5)
|
count = arguments.get("count", 5)
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ from mcp.shared._httpx_utils import create_mcp_http_client
|
|||||||
|
|
||||||
async def fetch_website(
|
async def fetch_website(
|
||||||
url: str,
|
url: str,
|
||||||
) -> list[
|
) -> list[types.Content]:
|
||||||
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
|
|
||||||
]:
|
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
||||||
}
|
}
|
||||||
@@ -31,14 +29,7 @@ def main(port: int, transport: str) -> int:
|
|||||||
app = Server("mcp-website-fetcher")
|
app = Server("mcp-website-fetcher")
|
||||||
|
|
||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def fetch_tool(
|
async def fetch_tool(name: str, arguments: dict) -> list[types.Content]:
|
||||||
name: str, arguments: dict
|
|
||||||
) -> list[
|
|
||||||
types.TextContent
|
|
||||||
| types.ImageContent
|
|
||||||
| types.AudioContent
|
|
||||||
| types.EmbeddedResource
|
|
||||||
]:
|
|
||||||
if name != "fetch":
|
if name != "fetch":
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
if "url" not in arguments:
|
if "url" not in arguments:
|
||||||
|
|||||||
@@ -7,18 +7,16 @@ from typing import Any, Literal
|
|||||||
import pydantic_core
|
import pydantic_core
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||||
|
|
||||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
from mcp.types import Content, TextContent
|
||||||
|
|
||||||
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
"""Base class for all prompt messages."""
|
"""Base class for all prompt messages."""
|
||||||
|
|
||||||
role: Literal["user", "assistant"]
|
role: Literal["user", "assistant"]
|
||||||
content: CONTENT_TYPES
|
content: Content
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
def __init__(self, content: str | Content, **kwargs: Any):
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
content = TextContent(type="text", text=content)
|
content = TextContent(type="text", text=content)
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
@@ -29,7 +27,7 @@ class UserMessage(Message):
|
|||||||
|
|
||||||
role: Literal["user", "assistant"] = "user"
|
role: Literal["user", "assistant"] = "user"
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
def __init__(self, content: str | Content, **kwargs: Any):
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -38,7 +36,7 @@ class AssistantMessage(Message):
|
|||||||
|
|
||||||
role: Literal["user", "assistant"] = "assistant"
|
role: Literal["user", "assistant"] = "assistant"
|
||||||
|
|
||||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
|
def __init__(self, content: str | Content, **kwargs: Any):
|
||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,10 +52,8 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
|||||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
AnyFunction,
|
AnyFunction,
|
||||||
AudioContent,
|
Content,
|
||||||
EmbeddedResource,
|
|
||||||
GetPromptResult,
|
GetPromptResult,
|
||||||
ImageContent,
|
|
||||||
TextContent,
|
TextContent,
|
||||||
ToolAnnotations,
|
ToolAnnotations,
|
||||||
)
|
)
|
||||||
@@ -256,9 +254,7 @@ class FastMCP:
|
|||||||
request_context = None
|
request_context = None
|
||||||
return Context(request_context=request_context, fastmcp=self)
|
return Context(request_context=request_context, fastmcp=self)
|
||||||
|
|
||||||
async def call_tool(
|
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
|
||||||
self, name: str, arguments: dict[str, Any]
|
|
||||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
|
||||||
"""Call a tool by name with arguments."""
|
"""Call a tool by name with arguments."""
|
||||||
context = self.get_context()
|
context = self.get_context()
|
||||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||||
@@ -842,12 +838,12 @@ class FastMCP:
|
|||||||
|
|
||||||
def _convert_to_content(
|
def _convert_to_content(
|
||||||
result: Any,
|
result: Any,
|
||||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
) -> Sequence[Content]:
|
||||||
"""Convert a result to a sequence of content objects."""
|
"""Convert a result to a sequence of content objects."""
|
||||||
if result is None:
|
if result is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
|
if isinstance(result, Content):
|
||||||
return [result]
|
return [result]
|
||||||
|
|
||||||
if isinstance(result, Image):
|
if isinstance(result, Image):
|
||||||
|
|||||||
@@ -384,9 +384,7 @@ class Server(Generic[LifespanResultT, RequestT]):
|
|||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[
|
func: Callable[
|
||||||
...,
|
...,
|
||||||
Awaitable[
|
Awaitable[Iterable[types.Content]],
|
||||||
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
|
|
||||||
],
|
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
logger.debug("Registering handler for CallToolRequest")
|
logger.debug("Registering handler for CallToolRequest")
|
||||||
|
|||||||
@@ -667,11 +667,14 @@ class EmbeddedResource(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
Content = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||||
|
|
||||||
|
|
||||||
class PromptMessage(BaseModel):
|
class PromptMessage(BaseModel):
|
||||||
"""Describes a message returned as part of a prompt."""
|
"""Describes a message returned as part of a prompt."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent | AudioContent | EmbeddedResource
|
content: Content
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@@ -787,7 +790,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
|||||||
class CallToolResult(Result):
|
class CallToolResult(Result):
|
||||||
"""The server's response to a tool call."""
|
"""The server's response to a tool call."""
|
||||||
|
|
||||||
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
|
content: list[Content]
|
||||||
isError: bool = False
|
isError: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from anyio.abc import TaskStatus
|
|||||||
from mcp.client.session import ClientSession
|
from mcp.client.session import ClientSession
|
||||||
from mcp.server.lowlevel import Server
|
from mcp.server.lowlevel import Server
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
from mcp.types import Content, TextContent
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
slow_request_complete = anyio.Event()
|
slow_request_complete = anyio.Event()
|
||||||
|
|
||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
async def slow_tool(name: str, arg) -> Sequence[Content]:
|
||||||
nonlocal request_count
|
nonlocal request_count
|
||||||
request_count += 1
|
request_count += 1
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from pydantic import AnyUrl
|
|||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
|
from mcp.server.fastmcp.prompts.base import Message, UserMessage
|
||||||
from mcp.server.fastmcp.resources import FileResource, FunctionResource
|
from mcp.server.fastmcp.resources import FileResource, FunctionResource
|
||||||
from mcp.server.fastmcp.utilities.types import Image
|
from mcp.server.fastmcp.utilities.types import Image
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
@@ -18,6 +18,8 @@ from mcp.shared.memory import (
|
|||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
AudioContent,
|
AudioContent,
|
||||||
BlobResourceContents,
|
BlobResourceContents,
|
||||||
|
Content,
|
||||||
|
EmbeddedResource,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
TextContent,
|
TextContent,
|
||||||
TextResourceContents,
|
TextResourceContents,
|
||||||
@@ -192,7 +194,7 @@ def image_tool_fn(path: str) -> Image:
|
|||||||
return Image(path)
|
return Image(path)
|
||||||
|
|
||||||
|
|
||||||
def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
|
def mixed_content_tool_fn() -> list[Content]:
|
||||||
return [
|
return [
|
||||||
TextContent(type="text", text="Hello"),
|
TextContent(type="text", text="Hello"),
|
||||||
ImageContent(type="image", data="abc", mimeType="image/png"),
|
ImageContent(type="image", data="abc", mimeType="image/png"),
|
||||||
|
|||||||
Reference in New Issue
Block a user