mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
feat: support audio content (#725)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
@@ -43,7 +43,12 @@ def main(
|
|||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
name: str, arguments: dict
|
name: str, arguments: dict
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
) -> list[
|
||||||
|
types.TextContent
|
||||||
|
| types.ImageContent
|
||||||
|
| types.AudioContent
|
||||||
|
| types.EmbeddedResource
|
||||||
|
]:
|
||||||
ctx = app.request_context
|
ctx = app.request_context
|
||||||
interval = arguments.get("interval", 1.0)
|
interval = arguments.get("interval", 1.0)
|
||||||
count = arguments.get("count", 5)
|
count = arguments.get("count", 5)
|
||||||
|
|||||||
@@ -47,7 +47,12 @@ def main(
|
|||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
name: str, arguments: dict
|
name: str, arguments: dict
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
) -> list[
|
||||||
|
types.TextContent
|
||||||
|
| types.ImageContent
|
||||||
|
| types.AudioContent
|
||||||
|
| types.EmbeddedResource
|
||||||
|
]:
|
||||||
ctx = app.request_context
|
ctx = app.request_context
|
||||||
interval = arguments.get("interval", 1.0)
|
interval = arguments.get("interval", 1.0)
|
||||||
count = arguments.get("count", 5)
|
count = arguments.get("count", 5)
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ from mcp.shared._httpx_utils import create_mcp_http_client
|
|||||||
|
|
||||||
async def fetch_website(
|
async def fetch_website(
|
||||||
url: str,
|
url: str,
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
) -> list[
|
||||||
|
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
|
||||||
|
]:
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
|
||||||
}
|
}
|
||||||
@@ -31,7 +33,12 @@ def main(port: int, transport: str) -> int:
|
|||||||
@app.call_tool()
|
@app.call_tool()
|
||||||
async def fetch_tool(
|
async def fetch_tool(
|
||||||
name: str, arguments: dict
|
name: str, arguments: dict
|
||||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
) -> list[
|
||||||
|
types.TextContent
|
||||||
|
| types.ImageContent
|
||||||
|
| types.AudioContent
|
||||||
|
| types.EmbeddedResource
|
||||||
|
]:
|
||||||
if name != "fetch":
|
if name != "fetch":
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
if "url" not in arguments:
|
if "url" not in arguments:
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ from typing import Any, Literal
|
|||||||
import pydantic_core
|
import pydantic_core
|
||||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||||
|
|
||||||
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
|
||||||
|
|
||||||
CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
|
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
|||||||
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
AnyFunction,
|
AnyFunction,
|
||||||
|
AudioContent,
|
||||||
EmbeddedResource,
|
EmbeddedResource,
|
||||||
GetPromptResult,
|
GetPromptResult,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
@@ -275,7 +276,7 @@ class FastMCP:
|
|||||||
|
|
||||||
async def call_tool(
|
async def call_tool(
|
||||||
self, name: str, arguments: dict[str, Any]
|
self, name: str, arguments: dict[str, Any]
|
||||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||||
"""Call a tool by name with arguments."""
|
"""Call a tool by name with arguments."""
|
||||||
context = self.get_context()
|
context = self.get_context()
|
||||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||||
@@ -875,12 +876,12 @@ class FastMCP:
|
|||||||
|
|
||||||
def _convert_to_content(
|
def _convert_to_content(
|
||||||
result: Any,
|
result: Any,
|
||||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||||
"""Convert a result to a sequence of content objects."""
|
"""Convert a result to a sequence of content objects."""
|
||||||
if result is None:
|
if result is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
|
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
|
||||||
return [result]
|
return [result]
|
||||||
|
|
||||||
if isinstance(result, Image):
|
if isinstance(result, Image):
|
||||||
|
|||||||
@@ -405,7 +405,10 @@ class Server(Generic[LifespanResultT, RequestT]):
|
|||||||
...,
|
...,
|
||||||
Awaitable[
|
Awaitable[
|
||||||
Iterable[
|
Iterable[
|
||||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
types.TextContent
|
||||||
|
| types.ImageContent
|
||||||
|
| types.AudioContent
|
||||||
|
| types.EmbeddedResource
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -657,11 +657,26 @@ class ImageContent(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioContent(BaseModel):
|
||||||
|
"""Audio content for a message."""
|
||||||
|
|
||||||
|
type: Literal["audio"]
|
||||||
|
data: str
|
||||||
|
"""The base64-encoded audio data."""
|
||||||
|
mimeType: str
|
||||||
|
"""
|
||||||
|
The MIME type of the audio. Different providers may support different
|
||||||
|
audio types.
|
||||||
|
"""
|
||||||
|
annotations: Annotations | None = None
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
class SamplingMessage(BaseModel):
|
class SamplingMessage(BaseModel):
|
||||||
"""Describes a message issued to or received from an LLM API."""
|
"""Describes a message issued to or received from an LLM API."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent
|
content: TextContent | ImageContent | AudioContent
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@@ -683,7 +698,7 @@ class PromptMessage(BaseModel):
|
|||||||
"""Describes a message returned as part of a prompt."""
|
"""Describes a message returned as part of a prompt."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent | EmbeddedResource
|
content: TextContent | ImageContent | AudioContent | EmbeddedResource
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
@@ -801,7 +816,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
|
|||||||
class CallToolResult(Result):
|
class CallToolResult(Result):
|
||||||
"""The server's response to a tool call."""
|
"""The server's response to a tool call."""
|
||||||
|
|
||||||
content: list[TextContent | ImageContent | EmbeddedResource]
|
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
|
||||||
isError: bool = False
|
isError: bool = False
|
||||||
|
|
||||||
|
|
||||||
@@ -965,7 +980,7 @@ class CreateMessageResult(Result):
|
|||||||
"""The client's response to a sampling/create_message request from the server."""
|
"""The client's response to a sampling/create_message request from the server."""
|
||||||
|
|
||||||
role: Role
|
role: Role
|
||||||
content: TextContent | ImageContent
|
content: TextContent | ImageContent | AudioContent
|
||||||
model: str
|
model: str
|
||||||
"""The name of the model that generated the message."""
|
"""The name of the model that generated the message."""
|
||||||
stopReason: StopReason | None = None
|
stopReason: StopReason | None = None
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from mcp.client.session import ClientSession
|
|||||||
from mcp.server.lowlevel import Server
|
from mcp.server.lowlevel import Server
|
||||||
from mcp.shared.exceptions import McpError
|
from mcp.shared.exceptions import McpError
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
AudioContent,
|
||||||
EmbeddedResource,
|
EmbeddedResource,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
TextContent,
|
TextContent,
|
||||||
@@ -37,7 +38,7 @@ async def test_notification_validation_error(tmp_path: Path):
|
|||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def slow_tool(
|
async def slow_tool(
|
||||||
name: str, arg
|
name: str, arg
|
||||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
|
||||||
nonlocal request_count
|
nonlocal request_count
|
||||||
request_count += 1
|
request_count += 1
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from mcp.shared.memory import (
|
|||||||
create_connected_server_and_client_session as client_session,
|
create_connected_server_and_client_session as client_session,
|
||||||
)
|
)
|
||||||
from mcp.types import (
|
from mcp.types import (
|
||||||
|
AudioContent,
|
||||||
BlobResourceContents,
|
BlobResourceContents,
|
||||||
ImageContent,
|
ImageContent,
|
||||||
TextContent,
|
TextContent,
|
||||||
@@ -207,10 +208,11 @@ def image_tool_fn(path: str) -> Image:
|
|||||||
return Image(path)
|
return Image(path)
|
||||||
|
|
||||||
|
|
||||||
def mixed_content_tool_fn() -> list[TextContent | ImageContent]:
|
def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]:
|
||||||
return [
|
return [
|
||||||
TextContent(type="text", text="Hello"),
|
TextContent(type="text", text="Hello"),
|
||||||
ImageContent(type="image", data="abc", mimeType="image/png"),
|
ImageContent(type="image", data="abc", mimeType="image/png"),
|
||||||
|
AudioContent(type="audio", data="def", mimeType="audio/wav"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -312,14 +314,16 @@ class TestServerTools:
|
|||||||
mcp.add_tool(mixed_content_tool_fn)
|
mcp.add_tool(mixed_content_tool_fn)
|
||||||
async with client_session(mcp._mcp_server) as client:
|
async with client_session(mcp._mcp_server) as client:
|
||||||
result = await client.call_tool("mixed_content_tool_fn", {})
|
result = await client.call_tool("mixed_content_tool_fn", {})
|
||||||
assert len(result.content) == 2
|
assert len(result.content) == 3
|
||||||
content1 = result.content[0]
|
content1, content2, content3 = result.content
|
||||||
content2 = result.content[1]
|
|
||||||
assert isinstance(content1, TextContent)
|
assert isinstance(content1, TextContent)
|
||||||
assert content1.text == "Hello"
|
assert content1.text == "Hello"
|
||||||
assert isinstance(content2, ImageContent)
|
assert isinstance(content2, ImageContent)
|
||||||
assert content2.mimeType == "image/png"
|
assert content2.mimeType == "image/png"
|
||||||
assert content2.data == "abc"
|
assert content2.data == "abc"
|
||||||
|
assert isinstance(content3, AudioContent)
|
||||||
|
assert content3.mimeType == "audio/wav"
|
||||||
|
assert content3.data == "def"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tool_mixed_list_with_image(self, tmp_path: Path):
|
async def test_tool_mixed_list_with_image(self, tmp_path: Path):
|
||||||
|
|||||||
Reference in New Issue
Block a user