mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-18 22:44:20 +01:00
feat: support audio content (#725)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
@@ -43,7 +43,12 @@ def main(
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
ctx = app.request_context
|
||||
interval = arguments.get("interval", 1.0)
|
||||
count = arguments.get("count", 5)
|
||||
|
||||
@@ -47,7 +47,12 @@ def main(
|
||||
@app.call_tool()
|
||||
async def call_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
ctx = app.request_context
|
||||
interval = arguments.get("interval", 1.0)
|
||||
count = arguments.get("count", 5)
|
||||
|
||||
@@ -7,7 +7,9 @@ from mcp.shared._httpx_utils import create_mcp_http_client
|
||||
|
||||
async def fetch_website(
|
||||
url: str,
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
) -> list[
|
||||
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
|
||||
]:
|
||||
headers = {
|
||||
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
||||
}
|
||||
@@ -31,7 +33,12 @@ def main(port: int, transport: str) -> int:
|
||||
@app.call_tool()
|
||||
async def fetch_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
) -> list[
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]:
|
||||
if name != "fetch":
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
if "url" not in arguments:
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Any, Literal
|
||||
import pydantic_core
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
||||
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
||||
|
||||
CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
|
||||
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
|
||||
@@ -52,6 +52,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||
from mcp.types import (
|
||||
AnyFunction,
|
||||
AudioContent,
|
||||
EmbeddedResource,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
@@ -275,7 +276,7 @@ class FastMCP:
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict[str, Any]
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||
@@ -875,12 +876,12 @@ class FastMCP:
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
"""Convert a result to a sequence of content objects."""
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
|
||||
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
|
||||
@@ -405,7 +405,10 @@ class Server(Generic[LifespanResultT, RequestT]):
|
||||
...,
|
||||
Awaitable[
|
||||
Iterable[
|
||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
||||
types.TextContent
|
||||
| types.ImageContent
|
||||
| types.AudioContent
|
||||
| types.EmbeddedResource
|
||||
]
|
||||
],
|
||||
],
|
||||
|
||||
@@ -657,11 +657,26 @@ class ImageContent(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class AudioContent(BaseModel):
|
||||
"""Audio content for a message."""
|
||||
|
||||
type: Literal["audio"]
|
||||
data: str
|
||||
"""The base64-encoded audio data."""
|
||||
mimeType: str
|
||||
"""
|
||||
The MIME type of the audio. Different providers may support different
|
||||
audio types.
|
||||
"""
|
||||
annotations: Annotations | None = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SamplingMessage(BaseModel):
|
||||
"""Describes a message issued to or received from an LLM API."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -683,7 +698,7 @@ class PromptMessage(BaseModel):
|
||||
"""Describes a message returned as part of a prompt."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent | EmbeddedResource
|
||||
content: TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -801,7 +816,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
||||
class CallToolResult(Result):
|
||||
"""The server's response to a tool call."""
|
||||
|
||||
content: list[TextContent | ImageContent | EmbeddedResource]
|
||||
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
|
||||
isError: bool = False
|
||||
|
||||
|
||||
@@ -965,7 +980,7 @@ class CreateMessageResult(Result):
|
||||
"""The client's response to a sampling/create_message request from the server."""
|
||||
|
||||
role: Role
|
||||
content: TextContent | ImageContent
|
||||
content: TextContent | ImageContent | AudioContent
|
||||
model: str
|
||||
"""The name of the model that generated the message."""
|
||||
stopReason: StopReason | None = None
|
||||
|
||||
@@ -12,6 +12,7 @@ from mcp.client.session import ClientSession
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
AudioContent,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
@@ -37,7 +38,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
||||
@server.call_tool()
|
||||
async def slow_tool(
|
||||
name: str, arg
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||
nonlocal request_count
|
||||
request_count += 1
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from mcp.shared.memory import (
|
||||
create_connected_server_and_client_session as client_session,
|
||||
)
|
||||
from mcp.types import (
|
||||
AudioContent,
|
||||
BlobResourceContents,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
@@ -207,10 +208,11 @@ def image_tool_fn(path: str) -> Image:
|
||||
return Image(path)
|
||||
|
||||
|
||||
def mixed_content_tool_fn() -> list[TextContent | ImageContent]:
|
||||
def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
|
||||
return [
|
||||
TextContent(type="text", text="Hello"),
|
||||
ImageContent(type="image", data="abc", mimeType="image/png"),
|
||||
AudioContent(type="audio", data="def", mimeType="audio/wav"),
|
||||
]
|
||||
|
||||
|
||||
@@ -312,14 +314,16 @@ class TestServerTools:
|
||||
mcp.add_tool(mixed_content_tool_fn)
|
||||
async with client_session(mcp._mcp_server) as client:
|
||||
result = await client.call_tool("mixed_content_tool_fn", {})
|
||||
assert len(result.content) == 2
|
||||
content1 = result.content[0]
|
||||
content2 = result.content[1]
|
||||
assert len(result.content) == 3
|
||||
content1, content2, content3 = result.content
|
||||
assert isinstance(content1, TextContent)
|
||||
assert content1.text == "Hello"
|
||||
assert isinstance(content2, ImageContent)
|
||||
assert content2.mimeType == "image/png"
|
||||
assert content2.data == "abc"
|
||||
assert isinstance(content3, AudioContent)
|
||||
assert content3.mimeType == "audio/wav"
|
||||
assert content3.data == "def"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_mixed_list_with_image(self, tmp_path: Path):
|
||||
|
||||
Reference in New Issue
Block a user