feat: support audio content (#725)

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Authored by Luca Chang on 2025-06-07 07:32:11 -07:00; committed by GitHub
commit 7123556a34 (parent 2bce10bdb1)
9 changed files with 60 additions and 19 deletions


@@ -43,7 +43,12 @@ def main(
     @app.call_tool()
     async def call_tool(
         name: str, arguments: dict
-    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+    ) -> list[
+        types.TextContent
+        | types.ImageContent
+        | types.AudioContent
+        | types.EmbeddedResource
+    ]:
         ctx = app.request_context
         interval = arguments.get("interval", 1.0)
         count = arguments.get("count", 5)


@@ -47,7 +47,12 @@ def main(
     @app.call_tool()
     async def call_tool(
         name: str, arguments: dict
-    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+    ) -> list[
+        types.TextContent
+        | types.ImageContent
+        | types.AudioContent
+        | types.EmbeddedResource
+    ]:
         ctx = app.request_context
         interval = arguments.get("interval", 1.0)
         count = arguments.get("count", 5)


@@ -7,7 +7,9 @@ from mcp.shared._httpx_utils import create_mcp_http_client
 
 async def fetch_website(
     url: str,
-) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+) -> list[
+    types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
+]:
     headers = {
         "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
     }
@@ -31,7 +33,12 @@ def main(port: int, transport: str) -> int:
     @app.call_tool()
     async def fetch_tool(
         name: str, arguments: dict
-    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+    ) -> list[
+        types.TextContent
+        | types.ImageContent
+        | types.AudioContent
+        | types.EmbeddedResource
+    ]:
         if name != "fetch":
             raise ValueError(f"Unknown tool: {name}")
         if "url" not in arguments:


@@ -7,9 +7,9 @@ from typing import Any, Literal
 import pydantic_core
 from pydantic import BaseModel, Field, TypeAdapter, validate_call
 
-from mcp.types import EmbeddedResource, ImageContent, TextContent
+from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
 
-CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
+CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
 
 
 class Message(BaseModel):
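
With AudioContent added to CONTENT_TYPES above, FastMCP prompt messages can carry audio alongside text, images, and embedded resources. A minimal sketch of checking a clip against the widened union (the base64 payload is a placeholder, not real audio):

    from mcp.server.fastmcp.prompts.base import CONTENT_TYPES
    from mcp.types import AudioContent

    # Placeholder payload; any base64-encoded string satisfies the `data` field.
    clip = AudioContent(type="audio", data="UklGRg==", mimeType="audio/wav")

    # CONTENT_TYPES is a PEP 604 union, so it works directly with isinstance().
    assert isinstance(clip, CONTENT_TYPES)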


@@ -52,6 +52,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
 from mcp.shared.context import LifespanContextT, RequestContext, RequestT
 from mcp.types import (
     AnyFunction,
+    AudioContent,
     EmbeddedResource,
     GetPromptResult,
     ImageContent,
@@ -275,7 +276,7 @@ class FastMCP:
 
     async def call_tool(
         self, name: str, arguments: dict[str, Any]
-    ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
+    ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
         """Call a tool by name with arguments."""
         context = self.get_context()
         result = await self._tool_manager.call_tool(name, arguments, context=context)
@@ -875,12 +876,12 @@ class FastMCP:
 
 def _convert_to_content(
     result: Any,
-) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
+) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
     """Convert a result to a sequence of content objects."""
     if result is None:
         return []
 
-    if isinstance(result, TextContent | ImageContent | EmbeddedResource):
+    if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
         return [result]
 
     if isinstance(result, Image):
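
Because _convert_to_content now recognizes AudioContent and passes it through untouched, a FastMCP tool can return audio directly. A minimal sketch with a hypothetical server name, tool, and stub payload (none of these appear in the commit):

    from mcp.server.fastmcp import FastMCP
    from mcp.types import AudioContent

    mcp = FastMCP("audio-demo")  # hypothetical server name


    @mcp.tool()
    def text_to_speech(text: str) -> AudioContent:
        # A real tool would synthesize `text`; this stub returns placeholder base64 data.
        return AudioContent(type="audio", data="UklGRg==", mimeType="audio/wav")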


@@ -405,7 +405,10 @@ class Server(Generic[LifespanResultT, RequestT]):
                 ...,
                 Awaitable[
                     Iterable[
-                        types.TextContent | types.ImageContent | types.EmbeddedResource
+                        types.TextContent
+                        | types.ImageContent
+                        | types.AudioContent
+                        | types.EmbeddedResource
                     ]
                 ],
             ],
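
The widened Callable type matches the handler signatures in the example servers above; a low-level call_tool handler may now include audio in its result list. A sketch with a hypothetical server name, tool name, and stub payload:

    import mcp.types as types
    from mcp.server.lowlevel import Server

    app = Server("audio-lowlevel-demo")  # hypothetical server name


    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[
        types.TextContent
        | types.ImageContent
        | types.AudioContent
        | types.EmbeddedResource
    ]:
        if name != "play_tone":
            raise ValueError(f"Unknown tool: {name}")
        # Placeholder base64 payload standing in for generated audio.
        return [types.AudioContent(type="audio", data="UklGRg==", mimeType="audio/wav")]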


@@ -657,11 +657,26 @@ class ImageContent(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class AudioContent(BaseModel):
+    """Audio content for a message."""
+
+    type: Literal["audio"]
+    data: str
+    """The base64-encoded audio data."""
+    mimeType: str
+    """
+    The MIME type of the audio. Different providers may support different
+    audio types.
+    """
+    annotations: Annotations | None = None
+    model_config = ConfigDict(extra="allow")
+
+
 class SamplingMessage(BaseModel):
     """Describes a message issued to or received from an LLM API."""
 
     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | AudioContent
     model_config = ConfigDict(extra="allow")
@@ -683,7 +698,7 @@ class PromptMessage(BaseModel):
     """Describes a message returned as part of a prompt."""
 
     role: Role
-    content: TextContent | ImageContent | EmbeddedResource
+    content: TextContent | ImageContent | AudioContent | EmbeddedResource
     model_config = ConfigDict(extra="allow")
@@ -801,7 +816,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
 class CallToolResult(Result):
     """The server's response to a tool call."""
 
-    content: list[TextContent | ImageContent | EmbeddedResource]
+    content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
     isError: bool = False
@@ -965,7 +980,7 @@ class CreateMessageResult(Result):
     """The client's response to a sampling/create_message request from the server."""
 
     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | AudioContent
     model: str
     """The name of the model that generated the message."""
     stopReason: StopReason | None = None
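
AudioContent deliberately mirrors ImageContent: base64-encoded data plus a mimeType, with optional annotations. A minimal construction sketch showing it flowing into a tool result (the encoded bytes are placeholders, not real audio):

    import base64

    from mcp.types import AudioContent, CallToolResult

    clip = AudioContent(
        type="audio",
        data=base64.b64encode(b"\x00\x01\x02").decode(),  # placeholder bytes
        mimeType="audio/wav",
    )

    # Tool results, prompt messages, and sampling messages all accept audio now.
    result = CallToolResult(content=[clip])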


@@ -12,6 +12,7 @@ from mcp.client.session import ClientSession
 from mcp.server.lowlevel import Server
 from mcp.shared.exceptions import McpError
 from mcp.types import (
+    AudioContent,
     EmbeddedResource,
     ImageContent,
     TextContent,
@@ -37,7 +38,7 @@ async def test_notification_validation_error(tmp_path: Path):
     @server.call_tool()
     async def slow_tool(
         name: str, arg
-    ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
+    ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
         nonlocal request_count
         request_count += 1


@@ -16,6 +16,7 @@ from mcp.shared.memory import (
     create_connected_server_and_client_session as client_session,
 )
 from mcp.types import (
+    AudioContent,
     BlobResourceContents,
     ImageContent,
     TextContent,
@@ -207,10 +208,11 @@ def image_tool_fn(path: str) -> Image:
     return Image(path)
 
 
-def mixed_content_tool_fn() -> list[TextContent | ImageContent]:
+def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
     return [
         TextContent(type="text", text="Hello"),
         ImageContent(type="image", data="abc", mimeType="image/png"),
+        AudioContent(type="audio", data="def", mimeType="audio/wav"),
     ]
@@ -312,14 +314,16 @@ class TestServerTools:
         mcp.add_tool(mixed_content_tool_fn)
         async with client_session(mcp._mcp_server) as client:
             result = await client.call_tool("mixed_content_tool_fn", {})
-            assert len(result.content) == 2
-            content1 = result.content[0]
-            content2 = result.content[1]
+            assert len(result.content) == 3
+            content1, content2, content3 = result.content
             assert isinstance(content1, TextContent)
             assert content1.text == "Hello"
             assert isinstance(content2, ImageContent)
             assert content2.mimeType == "image/png"
             assert content2.data == "abc"
+            assert isinstance(content3, AudioContent)
+            assert content3.mimeType == "audio/wav"
+            assert content3.data == "def"
 
     @pytest.mark.anyio
     async def test_tool_mixed_list_with_image(self, tmp_path: Path):