mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 14:54:24 +01:00
Integrate FastMCP
This commit integrates FastMCP, a high-level MCP server implementation originally written by Jeremiah Lowin, into the official MCP SDK. It also updates dependencies and adds new dev dependencies. It moves the existing SDK into a .lowlevel .
This commit is contained in:
@@ -1,500 +1,4 @@
|
||||
"""
|
||||
MCP Server Module
|
||||
from .lowlevel import Server, NotificationOptions
|
||||
from .fastmcp import FastMCP
|
||||
|
||||
This module provides a framework for creating an MCP (Model Context Protocol) server.
|
||||
It allows you to easily define and handle various types of requests and notifications
|
||||
in an asynchronous manner.
|
||||
|
||||
Usage:
|
||||
1. Create a Server instance:
|
||||
server = Server("your_server_name")
|
||||
|
||||
2. Define request handlers using decorators:
|
||||
@server.list_prompts()
|
||||
async def handle_list_prompts() -> list[types.Prompt]:
|
||||
# Implementation
|
||||
|
||||
@server.get_prompt()
|
||||
async def handle_get_prompt(
|
||||
name: str, arguments: dict[str, str] | None
|
||||
) -> types.GetPromptResult:
|
||||
# Implementation
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
# Implementation
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
# Implementation
|
||||
|
||||
@server.list_resource_templates()
|
||||
async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
|
||||
# Implementation
|
||||
|
||||
3. Define notification handlers if needed:
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int, progress: float, total: float | None
|
||||
) -> None:
|
||||
# Implementation
|
||||
|
||||
4. Run the server:
|
||||
async def main():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="your_server_name",
|
||||
server_version="your_version",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
The Server class provides methods to register handlers for various MCP requests and
|
||||
notifications. It automatically manages the request context and handles incoming
|
||||
messages from the client.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server as stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
def __init__(
|
||||
self,
|
||||
prompts_changed: bool = False,
|
||||
resources_changed: bool = False,
|
||||
tools_changed: bool = False,
|
||||
):
|
||||
self.prompts_changed = prompts_changed
|
||||
self.resources_changed = resources_changed
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
types.PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
self.notification_options = NotificationOptions()
|
||||
logger.debug(f"Initializing server '{name}'")
|
||||
|
||||
def create_initialization_options(
|
||||
self,
|
||||
notification_options: NotificationOptions | None = None,
|
||||
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
|
||||
) -> InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
v = version(package)
|
||||
if v is not None:
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "unknown"
|
||||
|
||||
return InitializationOptions(
|
||||
server_name=self.name,
|
||||
server_version=pkg_version("mcp"),
|
||||
capabilities=self.get_capabilities(
|
||||
notification_options or NotificationOptions(),
|
||||
experimental_capabilities or {},
|
||||
),
|
||||
)
|
||||
|
||||
def get_capabilities(
|
||||
self,
|
||||
notification_options: NotificationOptions,
|
||||
experimental_capabilities: dict[str, dict[str, Any]],
|
||||
) -> types.ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
prompts_capability = None
|
||||
resources_capability = None
|
||||
tools_capability = None
|
||||
logging_capability = None
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if types.ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = types.PromptsCapability(
|
||||
listChanged=notification_options.prompts_changed
|
||||
)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if types.ListResourcesRequest in self.request_handlers:
|
||||
resources_capability = types.ResourcesCapability(
|
||||
subscribe=False, listChanged=notification_options.resources_changed
|
||||
)
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
tools_capability = types.ToolsCapability(
|
||||
listChanged=notification_options.tools_changed
|
||||
)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if types.SetLevelRequest in self.request_handlers:
|
||||
logging_capability = types.LoggingCapability()
|
||||
|
||||
return types.ServerCapabilities(
|
||||
prompts=prompts_capability,
|
||||
resources=resources_capability,
|
||||
tools=tools_capability,
|
||||
logging=logging_capability,
|
||||
experimental=experimental_capabilities,
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[types.ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: types.GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(prompt_get)
|
||||
|
||||
self.request_handlers[types.GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourcesResult(resources=resources)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resource_templates(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
|
||||
logger.debug("Registering handler for ListResourceTemplatesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
templates = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourceTemplatesResult(resourceTemplates=templates)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourceTemplatesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: types.ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
match result:
|
||||
case str(s):
|
||||
content = types.TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=s,
|
||||
mimeType="text/plain",
|
||||
)
|
||||
case bytes(b):
|
||||
import base64
|
||||
|
||||
content = types.BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.urlsafe_b64encode(b).decode(),
|
||||
mimeType="application/octet-stream",
|
||||
)
|
||||
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: types.SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: types.SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: types.UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
|
||||
logger.debug("Registering handler for ListToolsRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
tools = await func()
|
||||
return types.ServerResult(types.ListToolsResult(tools=tools))
|
||||
|
||||
self.request_handlers[types.ListToolsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Sequence[
|
||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
||||
]
|
||||
],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: types.CallToolRequest):
|
||||
try:
|
||||
results = await func(req.params.name, (req.params.arguments or {}))
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(content=list(results), isError=False)
|
||||
)
|
||||
except Exception as e:
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[
|
||||
types.PromptReference | types.ResourceReference,
|
||||
types.CompletionArgument,
|
||||
],
|
||||
Awaitable[types.Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: types.CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return types.ServerResult(
|
||||
types.CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else types.Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
initialization_options: InitializationOptions,
|
||||
# When False, exceptions are returned as messages to the client.
|
||||
# When True, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=types.ClientRequest(root=req)):
|
||||
logger.info(
|
||||
f"Processing request of type {type(req).__name__}"
|
||||
)
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
except McpError as err:
|
||||
response = err.error
|
||||
except Exception as err:
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = types.ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None:
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
case types.ClientNotification(root=notify):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type "
|
||||
f"{type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: "
|
||||
f"{err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
|
||||
|
||||
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
__all__ = ["Server", "FastMCP", "NotificationOptions"]
|
||||
|
||||
8
src/mcp/server/fastmcp/__init__.py
Normal file
8
src/mcp/server/fastmcp/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""FastMCP - A more ergonomic interface for MCP servers."""
|
||||
|
||||
from importlib.metadata import version
|
||||
from .server import FastMCP, Context
|
||||
from .utilities.types import Image
|
||||
|
||||
__version__ = version("mcp")
|
||||
__all__ = ["FastMCP", "Context", "Image"]
|
||||
21
src/mcp/server/fastmcp/exceptions.py
Normal file
21
src/mcp/server/fastmcp/exceptions.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Custom exceptions for FastMCP."""
|
||||
|
||||
|
||||
class FastMCPError(Exception):
|
||||
"""Base error for FastMCP."""
|
||||
|
||||
|
||||
class ValidationError(FastMCPError):
|
||||
"""Error in validating parameters or return values."""
|
||||
|
||||
|
||||
class ResourceError(FastMCPError):
|
||||
"""Error in resource operations."""
|
||||
|
||||
|
||||
class ToolError(FastMCPError):
|
||||
"""Error in tool operations."""
|
||||
|
||||
|
||||
class InvalidSignature(Exception):
|
||||
"""Invalid signature for use with FastMCP."""
|
||||
4
src/mcp/server/fastmcp/prompts/__init__.py
Normal file
4
src/mcp/server/fastmcp/prompts/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import Prompt
|
||||
from .manager import PromptManager
|
||||
|
||||
__all__ = ["Prompt", "PromptManager"]
|
||||
166
src/mcp/server/fastmcp/prompts/base.py
Normal file
166
src/mcp/server/fastmcp/prompts/base.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Base classes for FastMCP prompts."""
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Sequence, Awaitable
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
from mcp.types import TextContent, ImageContent, EmbeddedResource
|
||||
import pydantic_core
|
||||
|
||||
CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Base class for all prompt messages."""
|
||||
|
||||
role: Literal["user", "assistant"]
|
||||
content: CONTENT_TYPES
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
if isinstance(content, str):
|
||||
content = TextContent(type="text", text=content)
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
"""A message from the user."""
|
||||
|
||||
role: Literal["user", "assistant"] = "user"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
class AssistantMessage(Message):
|
||||
"""A message from the assistant."""
|
||||
|
||||
role: Literal["user", "assistant"] = "assistant"
|
||||
|
||||
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
message_validator = TypeAdapter(UserMessage | AssistantMessage)
|
||||
|
||||
SyncPromptResult = (
|
||||
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
)
|
||||
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
|
||||
|
||||
|
||||
class PromptArgument(BaseModel):
|
||||
"""An argument that can be passed to a prompt."""
|
||||
|
||||
name: str = Field(description="Name of the argument")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the argument does"
|
||||
)
|
||||
required: bool = Field(
|
||||
default=False, description="Whether the argument is required"
|
||||
)
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt template that can be rendered with parameters."""
|
||||
|
||||
name: str = Field(description="Name of the prompt")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the prompt does"
|
||||
)
|
||||
arguments: list[PromptArgument] | None = Field(
|
||||
None, description="Arguments that can be passed to the prompt"
|
||||
)
|
||||
fn: Callable = Field(exclude=True)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable[..., PromptResult],
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> "Prompt":
|
||||
"""Create a Prompt from a function.
|
||||
|
||||
The function can return:
|
||||
- A string (converted to a message)
|
||||
- A Message object
|
||||
- A dict (converted to a message)
|
||||
- A sequence of any of the above
|
||||
"""
|
||||
func_name = name or fn.__name__
|
||||
|
||||
if func_name == "<lambda>":
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
# Get schema from TypeAdapter - will fail if function isn't properly typed
|
||||
parameters = TypeAdapter(fn).json_schema()
|
||||
|
||||
# Convert parameters to PromptArguments
|
||||
arguments = []
|
||||
if "properties" in parameters:
|
||||
for param_name, param in parameters["properties"].items():
|
||||
required = param_name in parameters.get("required", [])
|
||||
arguments.append(
|
||||
PromptArgument(
|
||||
name=param_name,
|
||||
description=param.get("description"),
|
||||
required=required,
|
||||
)
|
||||
)
|
||||
|
||||
# ensure the arguments are properly cast
|
||||
fn = validate_call(fn)
|
||||
|
||||
return cls(
|
||||
name=func_name,
|
||||
description=description or fn.__doc__ or "",
|
||||
arguments=arguments,
|
||||
fn=fn,
|
||||
)
|
||||
|
||||
async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
|
||||
"""Render the prompt with arguments."""
|
||||
# Validate required arguments
|
||||
if self.arguments:
|
||||
required = {arg.name for arg in self.arguments if arg.required}
|
||||
provided = set(arguments or {})
|
||||
missing = required - provided
|
||||
if missing:
|
||||
raise ValueError(f"Missing required arguments: {missing}")
|
||||
|
||||
try:
|
||||
# Call function and check if result is a coroutine
|
||||
result = self.fn(**(arguments or {}))
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
# Validate messages
|
||||
if not isinstance(result, (list, tuple)):
|
||||
result = [result]
|
||||
|
||||
# Convert result to messages
|
||||
messages = []
|
||||
for msg in result:
|
||||
try:
|
||||
if isinstance(msg, Message):
|
||||
messages.append(msg)
|
||||
elif isinstance(msg, dict):
|
||||
msg = message_validator.validate_python(msg)
|
||||
messages.append(msg)
|
||||
elif isinstance(msg, str):
|
||||
messages.append(
|
||||
UserMessage(content=TextContent(type="text", text=msg))
|
||||
)
|
||||
else:
|
||||
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
|
||||
messages.append(Message(role="user", content=msg))
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not convert prompt result to message: {msg}"
|
||||
)
|
||||
|
||||
return messages
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error rendering prompt {self.name}: {e}")
|
||||
50
src/mcp/server/fastmcp/prompts/manager.py
Normal file
50
src/mcp/server/fastmcp/prompts/manager.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Prompt management functionality."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.fastmcp.prompts.base import Message, Prompt
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Manages FastMCP prompts."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_prompts: bool = True):
|
||||
self._prompts: dict[str, Prompt] = {}
|
||||
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
|
||||
|
||||
def get_prompt(self, name: str) -> Prompt | None:
|
||||
"""Get prompt by name."""
|
||||
return self._prompts.get(name)
|
||||
|
||||
def list_prompts(self) -> list[Prompt]:
|
||||
"""List all registered prompts."""
|
||||
return list(self._prompts.values())
|
||||
|
||||
def add_prompt(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
) -> Prompt:
|
||||
"""Add a prompt to the manager."""
|
||||
|
||||
# Check for duplicates
|
||||
existing = self._prompts.get(prompt.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_prompts:
|
||||
logger.warning(f"Prompt already exists: {prompt.name}")
|
||||
return existing
|
||||
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def render_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> list[Message]:
|
||||
"""Render a prompt by name with arguments."""
|
||||
prompt = self.get_prompt(name)
|
||||
if not prompt:
|
||||
raise ValueError(f"Unknown prompt: {name}")
|
||||
|
||||
return await prompt.render(arguments)
|
||||
33
src/mcp/server/fastmcp/prompts/prompt_manager.py
Normal file
33
src/mcp/server/fastmcp/prompts/prompt_manager.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Prompt management functionality."""
|
||||
|
||||
from mcp.server.fastmcp.prompts.base import Prompt
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""Manages FastMCP prompts."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_prompts: bool = True):
|
||||
self._prompts: dict[str, Prompt] = {}
|
||||
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
|
||||
|
||||
def add_prompt(self, prompt: Prompt) -> Prompt:
|
||||
"""Add a prompt to the manager."""
|
||||
logger.debug(f"Adding prompt: {prompt.name}")
|
||||
existing = self._prompts.get(prompt.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_prompts:
|
||||
logger.warning(f"Prompt already exists: {prompt.name}")
|
||||
return existing
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
def get_prompt(self, name: str) -> Prompt | None:
|
||||
"""Get prompt by name."""
|
||||
return self._prompts.get(name)
|
||||
|
||||
def list_prompts(self) -> list[Prompt]:
|
||||
"""List all registered prompts."""
|
||||
return list(self._prompts.values())
|
||||
23
src/mcp/server/fastmcp/resources/__init__.py
Normal file
23
src/mcp/server/fastmcp/resources/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from .base import Resource
|
||||
from .types import (
|
||||
TextResource,
|
||||
BinaryResource,
|
||||
FunctionResource,
|
||||
FileResource,
|
||||
HttpResource,
|
||||
DirectoryResource,
|
||||
)
|
||||
from .templates import ResourceTemplate
|
||||
from .resource_manager import ResourceManager
|
||||
|
||||
__all__ = [
|
||||
"Resource",
|
||||
"TextResource",
|
||||
"BinaryResource",
|
||||
"FunctionResource",
|
||||
"FileResource",
|
||||
"HttpResource",
|
||||
"DirectoryResource",
|
||||
"ResourceTemplate",
|
||||
"ResourceManager",
|
||||
]
|
||||
48
src/mcp/server/fastmcp/resources/base.py
Normal file
48
src/mcp/server/fastmcp/resources/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Base classes and interfaces for FastMCP resources."""
|
||||
|
||||
import abc
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import (
|
||||
AnyUrl,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
UrlConstraints,
|
||||
ValidationInfo,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
class Resource(BaseModel, abc.ABC):
|
||||
"""Base class for all resources."""
|
||||
|
||||
model_config = ConfigDict(validate_default=True)
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
|
||||
default=..., description="URI of the resource"
|
||||
)
|
||||
name: str | None = Field(description="Name of the resource", default=None)
|
||||
description: str | None = Field(
|
||||
description="Description of the resource", default=None
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$",
|
||||
)
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
|
||||
"""Set default name from URI if not provided."""
|
||||
if name:
|
||||
return name
|
||||
if uri := info.data.get("uri"):
|
||||
return str(uri)
|
||||
raise ValueError("Either name or uri must be provided")
|
||||
|
||||
@abc.abstractmethod
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource content."""
|
||||
pass
|
||||
95
src/mcp/server/fastmcp/resources/resource_manager.py
Normal file
95
src/mcp/server/fastmcp/resources/resource_manager.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Resource manager functionality."""
|
||||
|
||||
from typing import Callable
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp.server.fastmcp.resources.base import Resource
|
||||
from mcp.server.fastmcp.resources.templates import ResourceTemplate
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ResourceManager:
|
||||
"""Manages FastMCP resources."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_resources: bool = True):
|
||||
self._resources: dict[str, Resource] = {}
|
||||
self._templates: dict[str, ResourceTemplate] = {}
|
||||
self.warn_on_duplicate_resources = warn_on_duplicate_resources
|
||||
|
||||
def add_resource(self, resource: Resource) -> Resource:
|
||||
"""Add a resource to the manager.
|
||||
|
||||
Args:
|
||||
resource: A Resource instance to add
|
||||
|
||||
Returns:
|
||||
The added resource. If a resource with the same URI already exists,
|
||||
returns the existing resource.
|
||||
"""
|
||||
logger.debug(
|
||||
"Adding resource",
|
||||
extra={
|
||||
"uri": resource.uri,
|
||||
"type": type(resource).__name__,
|
||||
"name": resource.name,
|
||||
},
|
||||
)
|
||||
existing = self._resources.get(str(resource.uri))
|
||||
if existing:
|
||||
if self.warn_on_duplicate_resources:
|
||||
logger.warning(f"Resource already exists: {resource.uri}")
|
||||
return existing
|
||||
self._resources[str(resource.uri)] = resource
|
||||
return resource
|
||||
|
||||
def add_template(
|
||||
self,
|
||||
fn: Callable,
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> ResourceTemplate:
|
||||
"""Add a template from a function."""
|
||||
template = ResourceTemplate.from_function(
|
||||
fn,
|
||||
uri_template=uri_template,
|
||||
name=name,
|
||||
description=description,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
self._templates[template.uri_template] = template
|
||||
return template
|
||||
|
||||
async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
|
||||
"""Get resource by URI, checking concrete resources first, then templates."""
|
||||
uri_str = str(uri)
|
||||
logger.debug("Getting resource", extra={"uri": uri_str})
|
||||
|
||||
# First check concrete resources
|
||||
if resource := self._resources.get(uri_str):
|
||||
return resource
|
||||
|
||||
# Then check templates
|
||||
for template in self._templates.values():
|
||||
if params := template.matches(uri_str):
|
||||
try:
|
||||
return await template.create_resource(uri_str, params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating resource from template: {e}")
|
||||
|
||||
raise ValueError(f"Unknown resource: {uri}")
|
||||
|
||||
def list_resources(self) -> list[Resource]:
|
||||
"""List all registered resources."""
|
||||
logger.debug("Listing resources", extra={"count": len(self._resources)})
|
||||
return list(self._resources.values())
|
||||
|
||||
def list_templates(self) -> list[ResourceTemplate]:
|
||||
"""List all registered templates."""
|
||||
logger.debug("Listing templates", extra={"count": len(self._templates)})
|
||||
return list(self._templates.values())
|
||||
80
src/mcp/server/fastmcp/resources/templates.py
Normal file
80
src/mcp/server/fastmcp/resources/templates.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Resource template functionality."""
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, validate_call
|
||||
|
||||
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
|
||||
|
||||
|
||||
class ResourceTemplate(BaseModel):
|
||||
"""A template for dynamically creating resources."""
|
||||
|
||||
uri_template: str = Field(
|
||||
description="URI template with parameters (e.g. weather://{city}/current)"
|
||||
)
|
||||
name: str = Field(description="Name of the resource")
|
||||
description: str | None = Field(description="Description of what the resource does")
|
||||
mime_type: str = Field(
|
||||
default="text/plain", description="MIME type of the resource content"
|
||||
)
|
||||
fn: Callable = Field(exclude=True)
|
||||
parameters: dict = Field(description="JSON schema for function parameters")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable,
|
||||
uri_template: str,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> "ResourceTemplate":
|
||||
"""Create a template from a function."""
|
||||
func_name = name or fn.__name__
|
||||
if func_name == "<lambda>":
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
# Get schema from TypeAdapter - will fail if function isn't properly typed
|
||||
parameters = TypeAdapter(fn).json_schema()
|
||||
|
||||
# ensure the arguments are properly cast
|
||||
fn = validate_call(fn)
|
||||
|
||||
return cls(
|
||||
uri_template=uri_template,
|
||||
name=func_name,
|
||||
description=description or fn.__doc__ or "",
|
||||
mime_type=mime_type or "text/plain",
|
||||
fn=fn,
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
def matches(self, uri: str) -> dict[str, Any] | None:
|
||||
"""Check if URI matches template and extract parameters."""
|
||||
# Convert template to regex pattern
|
||||
pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
|
||||
match = re.match(f"^{pattern}$", uri)
|
||||
if match:
|
||||
return match.groupdict()
|
||||
return None
|
||||
|
||||
async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
|
||||
"""Create a resource from the template with the given parameters."""
|
||||
try:
|
||||
# Call function and check if result is a coroutine
|
||||
result = self.fn(**params)
|
||||
if inspect.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
return FunctionResource(
|
||||
uri=uri, # type: ignore
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
mime_type=self.mime_type,
|
||||
fn=lambda: result, # Capture result in closure
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating resource from template: {e}")
|
||||
181
src/mcp/server/fastmcp/resources/types.py
Normal file
181
src/mcp/server/fastmcp/resources/types.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Concrete resource implementations."""
|
||||
|
||||
import anyio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
import pydantic.json
|
||||
import pydantic_core
|
||||
from pydantic import Field, ValidationInfo
|
||||
|
||||
from mcp.server.fastmcp.resources.base import Resource
|
||||
|
||||
|
||||
class TextResource(Resource):
|
||||
"""A resource that reads from a string."""
|
||||
|
||||
text: str = Field(description="Text content of the resource")
|
||||
|
||||
async def read(self) -> str:
|
||||
"""Read the text content."""
|
||||
return self.text
|
||||
|
||||
|
||||
class BinaryResource(Resource):
|
||||
"""A resource that reads from bytes."""
|
||||
|
||||
data: bytes = Field(description="Binary content of the resource")
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Read the binary content."""
|
||||
return self.data
|
||||
|
||||
|
||||
class FunctionResource(Resource):
|
||||
"""A resource that defers data loading by wrapping a function.
|
||||
|
||||
The function is only called when the resource is read, allowing for lazy loading
|
||||
of potentially expensive data. This is particularly useful when listing resources,
|
||||
as the function won't be called until the resource is actually accessed.
|
||||
|
||||
The function can return:
|
||||
- str for text content (default)
|
||||
- bytes for binary content
|
||||
- other types will be converted to JSON
|
||||
"""
|
||||
|
||||
fn: Callable[[], Any] = Field(exclude=True)
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource by calling the wrapped function."""
|
||||
try:
|
||||
result = self.fn()
|
||||
if isinstance(result, Resource):
|
||||
return await result.read()
|
||||
if isinstance(result, bytes):
|
||||
return result
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
try:
|
||||
return json.dumps(pydantic_core.to_jsonable_python(result))
|
||||
except (TypeError, pydantic_core.PydanticSerializationError):
|
||||
# If JSON serialization fails, try str()
|
||||
return str(result)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading resource {self.uri}: {e}")
|
||||
|
||||
|
||||
class FileResource(Resource):
|
||||
"""A resource that reads from a file.
|
||||
|
||||
Set is_binary=True to read file as binary data instead of text.
|
||||
"""
|
||||
|
||||
path: Path = Field(description="Path to the file")
|
||||
is_binary: bool = Field(
|
||||
default=False,
|
||||
description="Whether to read the file as binary data",
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
)
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
def validate_absolute_path(cls, path: Path) -> Path:
|
||||
"""Ensure path is absolute."""
|
||||
if not path.is_absolute():
|
||||
raise ValueError("Path must be absolute")
|
||||
return path
|
||||
|
||||
@pydantic.field_validator("is_binary")
|
||||
@classmethod
|
||||
def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
|
||||
"""Set is_binary based on mime_type if not explicitly set."""
|
||||
if is_binary:
|
||||
return True
|
||||
mime_type = info.data.get("mime_type", "text/plain")
|
||||
return not mime_type.startswith("text/")
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the file content."""
|
||||
try:
|
||||
if self.is_binary:
|
||||
return await anyio.to_thread.run_sync(self.path.read_bytes)
|
||||
return await anyio.to_thread.run_sync(self.path.read_text)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading file {self.path}: {e}")
|
||||
|
||||
|
||||
class HttpResource(Resource):
|
||||
"""A resource that reads from an HTTP endpoint."""
|
||||
|
||||
url: str = Field(description="URL to fetch content from")
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the HTTP content."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
|
||||
class DirectoryResource(Resource):
|
||||
"""A resource that lists files in a directory."""
|
||||
|
||||
path: Path = Field(description="Path to the directory")
|
||||
recursive: bool = Field(
|
||||
default=False, description="Whether to list files recursively"
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
default=None, description="Optional glob pattern to filter files"
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
def validate_absolute_path(cls, path: Path) -> Path:
|
||||
"""Ensure path is absolute."""
|
||||
if not path.is_absolute():
|
||||
raise ValueError("Path must be absolute")
|
||||
return path
|
||||
|
||||
def list_files(self) -> list[Path]:
|
||||
"""List files in the directory."""
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"Directory not found: {self.path}")
|
||||
if not self.path.is_dir():
|
||||
raise NotADirectoryError(f"Not a directory: {self.path}")
|
||||
|
||||
try:
|
||||
if self.pattern:
|
||||
return (
|
||||
list(self.path.glob(self.pattern))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob(self.pattern))
|
||||
)
|
||||
return (
|
||||
list(self.path.glob("*"))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob("*"))
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error listing directory {self.path}: {e}")
|
||||
|
||||
async def read(self) -> str: # Always returns JSON string
|
||||
"""Read the directory listing."""
|
||||
try:
|
||||
files = await anyio.to_thread.run_sync(self.list_files)
|
||||
file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()]
|
||||
return json.dumps({"files": file_list}, indent=2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading directory {self.path}: {e}")
|
||||
668
src/mcp/server/fastmcp/server.py
Normal file
668
src/mcp/server/fastmcp/server.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""FastMCP - A more ergonomic interface for MCP servers."""
|
||||
|
||||
import anyio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Literal, Sequence
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pydantic_core
|
||||
from pydantic import Field
|
||||
import uvicorn
|
||||
from mcp.server.lowlevel import Server as MCPServer
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.types import (
|
||||
EmbeddedResource,
|
||||
GetPromptResult,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import (
|
||||
Prompt as MCPPrompt,
|
||||
PromptArgument as MCPPromptArgument,
|
||||
)
|
||||
from mcp.types import (
|
||||
Resource as MCPResource,
|
||||
)
|
||||
from mcp.types import (
|
||||
ResourceTemplate as MCPResourceTemplate,
|
||||
)
|
||||
from mcp.types import (
|
||||
Tool as MCPTool,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from mcp.server.fastmcp.exceptions import ResourceError
|
||||
from mcp.server.fastmcp.prompts import Prompt, PromptManager
|
||||
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
|
||||
from mcp.server.fastmcp.tools import ToolManager
|
||||
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""FastMCP server settings.
|
||||
|
||||
All settings can be configured via environment variables with the prefix FASTMCP_.
|
||||
For example, FASTMCP_DEBUG=true will set debug=True.
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="FASTMCP_",
|
||||
env_file=".env",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Server settings
|
||||
debug: bool = False
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
|
||||
# HTTP settings
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
|
||||
# tool settings
|
||||
warn_on_duplicate_tools: bool = True
|
||||
|
||||
# prompt settings
|
||||
warn_on_duplicate_prompts: bool = True
|
||||
|
||||
dependencies: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
|
||||
class FastMCP:
|
||||
def __init__(self, name: str | None = None, **settings: Any):
|
||||
self.settings = Settings(**settings)
|
||||
self._mcp_server = MCPServer(name=name or "FastMCP")
|
||||
self._tool_manager = ToolManager(
|
||||
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
self._resource_manager = ResourceManager(
|
||||
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
|
||||
)
|
||||
self._prompt_manager = PromptManager(
|
||||
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
|
||||
)
|
||||
self.dependencies = self.settings.dependencies
|
||||
|
||||
# Set up MCP protocol handlers
|
||||
self._setup_handlers()
|
||||
|
||||
# Configure logging
|
||||
configure_logging(self.settings.log_level)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._mcp_server.name
|
||||
|
||||
def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None:
|
||||
"""Run the FastMCP server. Note this is a synchronous function.
|
||||
|
||||
Args:
|
||||
transport: Transport protocol to use ("stdio" or "sse")
|
||||
"""
|
||||
TRANSPORTS = Literal["stdio", "sse"]
|
||||
if transport not in TRANSPORTS.__args__: # type: ignore
|
||||
raise ValueError(f"Unknown transport: {transport}")
|
||||
|
||||
if transport == "stdio":
|
||||
anyio.run(self.run_stdio_async)
|
||||
else: # transport == "sse"
|
||||
anyio.run(self.run_sse_async)
|
||||
|
||||
def _setup_handlers(self) -> None:
|
||||
"""Set up core MCP protocol handlers."""
|
||||
self._mcp_server.list_tools()(self.list_tools)
|
||||
self._mcp_server.call_tool()(self.call_tool)
|
||||
self._mcp_server.list_resources()(self.list_resources)
|
||||
self._mcp_server.read_resource()(self.read_resource)
|
||||
self._mcp_server.list_prompts()(self.list_prompts)
|
||||
self._mcp_server.get_prompt()(self.get_prompt)
|
||||
# TODO: This has not been added to MCP yet, see https://github.com/jlowin/fastmcp/issues/10
|
||||
# self._mcp_server.list_resource_templates()(self.list_resource_templates)
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""List all available tools."""
|
||||
tools = self._tool_manager.list_tools()
|
||||
return [
|
||||
MCPTool(
|
||||
name=info.name,
|
||||
description=info.description,
|
||||
inputSchema=info.parameters,
|
||||
)
|
||||
for info in tools
|
||||
]
|
||||
|
||||
def get_context(self) -> "Context":
|
||||
"""
|
||||
Returns a Context object. Note that the context will only be valid
|
||||
during a request; outside a request, most methods will error.
|
||||
"""
|
||||
try:
|
||||
request_context = self._mcp_server.request_context
|
||||
except LookupError:
|
||||
request_context = None
|
||||
return Context(request_context=request_context, fastmcp=self)
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
"""Call a tool by name with arguments."""
|
||||
context = self.get_context()
|
||||
result = await self._tool_manager.call_tool(name, arguments, context=context)
|
||||
converted_result = _convert_to_content(result)
|
||||
return converted_result
|
||||
|
||||
async def list_resources(self) -> list[MCPResource]:
|
||||
"""List all available resources."""
|
||||
|
||||
resources = self._resource_manager.list_resources()
|
||||
return [
|
||||
MCPResource(
|
||||
uri=resource.uri,
|
||||
name=resource.name or "",
|
||||
description=resource.description,
|
||||
mimeType=resource.mime_type,
|
||||
)
|
||||
for resource in resources
|
||||
]
|
||||
|
||||
async def list_resource_templates(self) -> list[MCPResourceTemplate]:
|
||||
templates = self._resource_manager.list_templates()
|
||||
return [
|
||||
MCPResourceTemplate(
|
||||
uriTemplate=template.uri_template,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
)
|
||||
for template in templates
|
||||
]
|
||||
|
||||
async def read_resource(self, uri: AnyUrl | str) -> str | bytes:
|
||||
"""Read a resource by URI."""
|
||||
resource = await self._resource_manager.get_resource(uri)
|
||||
if not resource:
|
||||
raise ResourceError(f"Unknown resource: {uri}")
|
||||
|
||||
try:
|
||||
return await resource.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading resource {uri}: {e}")
|
||||
raise ResourceError(str(e))
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: Callable,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> None:
|
||||
"""Add a tool to the server.
|
||||
|
||||
The tool function can optionally request a Context object by adding a parameter
|
||||
with the Context type annotation. See the @tool decorator for examples.
|
||||
|
||||
Args:
|
||||
fn: The function to register as a tool
|
||||
name: Optional name for the tool (defaults to function name)
|
||||
description: Optional description of what the tool does
|
||||
"""
|
||||
self._tool_manager.add_tool(fn, name=name, description=description)
|
||||
|
||||
def tool(self, name: str | None = None, description: str | None = None) -> Callable:
|
||||
"""Decorator to register a tool.
|
||||
|
||||
Tools can optionally request a Context object by adding a parameter with the Context type annotation.
|
||||
The context provides access to MCP capabilities like logging, progress reporting, and resource access.
|
||||
|
||||
Args:
|
||||
name: Optional name for the tool (defaults to function name)
|
||||
description: Optional description of what the tool does
|
||||
|
||||
Example:
|
||||
@server.tool()
|
||||
def my_tool(x: int) -> str:
|
||||
return str(x)
|
||||
|
||||
@server.tool()
|
||||
def tool_with_context(x: int, ctx: Context) -> str:
|
||||
ctx.info(f"Processing {x}")
|
||||
return str(x)
|
||||
|
||||
@server.tool()
|
||||
async def async_tool(x: int, context: Context) -> str:
|
||||
await context.report_progress(50, 100)
|
||||
return str(x)
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @tool decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @tool() instead of @tool"
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
self.add_tool(fn, name=name, description=description)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
def add_resource(self, resource: Resource) -> None:
|
||||
"""Add a resource to the server.
|
||||
|
||||
Args:
|
||||
resource: A Resource instance to add
|
||||
"""
|
||||
self._resource_manager.add_resource(resource)
|
||||
|
||||
def resource(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> Callable:
|
||||
"""Decorator to register a function as a resource.
|
||||
|
||||
The function will be called when the resource is read to generate its content.
|
||||
The function can return:
|
||||
- str for text content
|
||||
- bytes for binary content
|
||||
- other types will be converted to JSON
|
||||
|
||||
If the URI contains parameters (e.g. "resource://{param}") or the function
|
||||
has parameters, it will be registered as a template resource.
|
||||
|
||||
Args:
|
||||
uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}")
|
||||
name: Optional name for the resource
|
||||
description: Optional description of the resource
|
||||
mime_type: Optional MIME type for the resource
|
||||
|
||||
Example:
|
||||
@server.resource("resource://my-resource")
|
||||
def get_data() -> str:
|
||||
return "Hello, world!"
|
||||
|
||||
@server.resource("resource://{city}/weather")
|
||||
def get_weather(city: str) -> str:
|
||||
return f"Weather for {city}"
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(uri):
|
||||
raise TypeError(
|
||||
"The @resource decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @resource('uri') instead of @resource"
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
# Check if this should be a template
|
||||
has_uri_params = "{" in uri and "}" in uri
|
||||
has_func_params = bool(inspect.signature(fn).parameters)
|
||||
|
||||
if has_uri_params or has_func_params:
|
||||
# Validate that URI params match function params
|
||||
uri_params = set(re.findall(r"{(\w+)}", uri))
|
||||
func_params = set(inspect.signature(fn).parameters.keys())
|
||||
|
||||
if uri_params != func_params:
|
||||
raise ValueError(
|
||||
f"Mismatch between URI parameters {uri_params} "
|
||||
f"and function parameters {func_params}"
|
||||
)
|
||||
|
||||
# Register as template
|
||||
self._resource_manager.add_template(
|
||||
wrapper,
|
||||
uri_template=uri,
|
||||
name=name,
|
||||
description=description,
|
||||
mime_type=mime_type or "text/plain",
|
||||
)
|
||||
else:
|
||||
# Register as regular resource
|
||||
resource = FunctionResource(
|
||||
uri=AnyUrl(uri),
|
||||
name=name,
|
||||
description=description,
|
||||
mime_type=mime_type or "text/plain",
|
||||
fn=wrapper,
|
||||
)
|
||||
self.add_resource(resource)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def add_prompt(self, prompt: Prompt) -> None:
|
||||
"""Add a prompt to the server.
|
||||
|
||||
Args:
|
||||
prompt: A Prompt instance to add
|
||||
"""
|
||||
self._prompt_manager.add_prompt(prompt)
|
||||
|
||||
def prompt(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable:
|
||||
"""Decorator to register a prompt.
|
||||
|
||||
Args:
|
||||
name: Optional name for the prompt (defaults to function name)
|
||||
description: Optional description of what the prompt does
|
||||
|
||||
Example:
|
||||
@server.prompt()
|
||||
def analyze_table(table_name: str) -> list[Message]:
|
||||
schema = read_table_schema(table_name)
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Analyze this schema:\n{schema}"
|
||||
}
|
||||
]
|
||||
|
||||
@server.prompt()
|
||||
async def analyze_file(path: str) -> list[Message]:
|
||||
content = await read_file(path)
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "resource",
|
||||
"resource": {
|
||||
"uri": f"file://{path}",
|
||||
"text": content
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @prompt decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @prompt() instead of @prompt"
|
||||
)
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
prompt = Prompt.from_function(func, name=name, description=description)
|
||||
self.add_prompt(prompt)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run_stdio_async(self) -> None:
|
||||
"""Run the server using stdio transport."""
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await self._mcp_server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
async def run_sse_async(self) -> None:
|
||||
"""Run the server using SSE transport."""
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
sse = SseServerTransport("/messages")
|
||||
|
||||
async def handle_sse(request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
await self._mcp_server.run(
|
||||
streams[0],
|
||||
streams[1],
|
||||
self._mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
async def handle_messages(request):
|
||||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
||||
|
||||
starlette_app = Starlette(
|
||||
debug=self.settings.debug,
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Route("/messages", endpoint=handle_messages, methods=["POST"]),
|
||||
],
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
starlette_app,
|
||||
host=self.settings.host,
|
||||
port=self.settings.port,
|
||||
log_level=self.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
async def list_prompts(self) -> list[MCPPrompt]:
|
||||
"""List all available prompts."""
|
||||
prompts = self._prompt_manager.list_prompts()
|
||||
return [
|
||||
MCPPrompt(
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
arguments=[
|
||||
MCPPromptArgument(
|
||||
name=arg.name,
|
||||
description=arg.description,
|
||||
required=arg.required,
|
||||
)
|
||||
for arg in (prompt.arguments or [])
|
||||
],
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> GetPromptResult:
|
||||
"""Get a prompt by name with arguments."""
|
||||
try:
|
||||
messages = await self._prompt_manager.render_prompt(name, arguments)
|
||||
|
||||
return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages))
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prompt {name}: {e}")
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
def _convert_to_content(
|
||||
result: Any,
|
||||
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
|
||||
"""Convert a result to a sequence of content objects."""
|
||||
if result is None:
|
||||
return []
|
||||
|
||||
if isinstance(result, (TextContent, ImageContent, EmbeddedResource)):
|
||||
return [result]
|
||||
|
||||
if isinstance(result, Image):
|
||||
return [result.to_image_content()]
|
||||
|
||||
if isinstance(result, (list, tuple)):
|
||||
return list(chain.from_iterable(_convert_to_content(item) for item in result))
|
||||
|
||||
if not isinstance(result, str):
|
||||
try:
|
||||
result = json.dumps(pydantic_core.to_jsonable_python(result))
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
"""Context object providing access to MCP capabilities.
|
||||
|
||||
This provides a cleaner interface to MCP's RequestContext functionality.
|
||||
It gets injected into tool and resource functions that request it via type hints.
|
||||
|
||||
To use context in a tool function, add a parameter with the Context type annotation:
|
||||
|
||||
```python
|
||||
@server.tool()
|
||||
def my_tool(x: int, ctx: Context) -> str:
|
||||
# Log messages to the client
|
||||
ctx.info(f"Processing {x}")
|
||||
ctx.debug("Debug info")
|
||||
ctx.warning("Warning message")
|
||||
ctx.error("Error message")
|
||||
|
||||
# Report progress
|
||||
ctx.report_progress(50, 100)
|
||||
|
||||
# Access resources
|
||||
data = ctx.read_resource("resource://data")
|
||||
|
||||
# Get request info
|
||||
request_id = ctx.request_id
|
||||
client_id = ctx.client_id
|
||||
|
||||
return str(x)
|
||||
```
|
||||
|
||||
The context parameter name can be anything as long as it's annotated with Context.
|
||||
The context is optional - tools that don't need it can omit the parameter.
|
||||
"""
|
||||
|
||||
_request_context: RequestContext | None
|
||||
_fastmcp: FastMCP | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_context: RequestContext | None = None,
|
||||
fastmcp: FastMCP | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._request_context = request_context
|
||||
self._fastmcp = fastmcp
|
||||
|
||||
@property
|
||||
def fastmcp(self) -> FastMCP:
|
||||
"""Access to the FastMCP server."""
|
||||
if self._fastmcp is None:
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._fastmcp
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext:
|
||||
"""Access to the underlying request context."""
|
||||
if self._request_context is None:
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
progress: Current progress value e.g. 24
|
||||
total: Optional total value e.g. 100
|
||||
"""
|
||||
|
||||
progress_token = (
|
||||
self.request_context.meta.progressToken
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
|
||||
if not progress_token:
|
||||
return
|
||||
|
||||
await self.request_context.session.send_progress_notification(
|
||||
progress_token=progress_token, progress=progress, total=total
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: str | AnyUrl) -> str | bytes:
|
||||
"""Read a resource by URI.
|
||||
|
||||
Args:
|
||||
uri: Resource URI to read
|
||||
|
||||
Returns:
|
||||
The resource content as either text or bytes
|
||||
"""
|
||||
assert (
|
||||
self._fastmcp is not None
|
||||
), "Context is not available outside of a request"
|
||||
return await self._fastmcp.read_resource(uri)
|
||||
|
||||
def log(
|
||||
self,
|
||||
level: Literal["debug", "info", "warning", "error"],
|
||||
message: str,
|
||||
*,
|
||||
logger_name: str | None = None,
|
||||
) -> None:
|
||||
"""Send a log message to the client.
|
||||
|
||||
Args:
|
||||
level: Log level (debug, info, warning, error)
|
||||
message: Log message
|
||||
logger_name: Optional logger name
|
||||
**extra: Additional structured data to include
|
||||
"""
|
||||
self.request_context.session.send_log_message(
|
||||
level=level, data=message, logger=logger_name
|
||||
)
|
||||
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
"""Get the client ID if available."""
|
||||
return (
|
||||
getattr(self.request_context.meta, "client_id", None)
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
"""Get the unique ID for this request."""
|
||||
return str(self.request_context.request_id)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Access to the underlying session for advanced usage."""
|
||||
return self.request_context.session
|
||||
|
||||
# Convenience methods for common log levels
|
||||
def debug(self, message: str, **extra: Any) -> None:
|
||||
"""Send a debug log message."""
|
||||
self.log("debug", message, **extra)
|
||||
|
||||
def info(self, message: str, **extra: Any) -> None:
|
||||
"""Send an info log message."""
|
||||
self.log("info", message, **extra)
|
||||
|
||||
def warning(self, message: str, **extra: Any) -> None:
|
||||
"""Send a warning log message."""
|
||||
self.log("warning", message, **extra)
|
||||
|
||||
def error(self, message: str, **extra: Any) -> None:
|
||||
"""Send an error log message."""
|
||||
self.log("error", message, **extra)
|
||||
4
src/mcp/server/fastmcp/tools/__init__.py
Normal file
4
src/mcp/server/fastmcp/tools/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import Tool
|
||||
from .tool_manager import ToolManager
|
||||
|
||||
__all__ = ["Tool", "ToolManager"]
|
||||
82
src/mcp/server/fastmcp/tools/base.py
Normal file
82
src/mcp/server/fastmcp/tools/base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import mcp.server.fastmcp
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.utilities.func_metadata import func_metadata, FuncMetadata
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""Internal tool registration info."""
|
||||
|
||||
fn: Callable = Field(exclude=True)
|
||||
name: str = Field(description="Name of the tool")
|
||||
description: str = Field(description="Description of what the tool does")
|
||||
parameters: dict = Field(description="JSON schema for tool parameters")
|
||||
fn_metadata: FuncMetadata = Field(
|
||||
description="Metadata about the function including a pydantic model for tool arguments"
|
||||
)
|
||||
is_async: bool = Field(description="Whether the tool is async")
|
||||
context_kwarg: str | None = Field(
|
||||
None, description="Name of the kwarg that should receive context"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
fn: Callable,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
context_kwarg: str | None = None,
|
||||
) -> "Tool":
|
||||
"""Create a Tool from a function."""
|
||||
func_name = name or fn.__name__
|
||||
|
||||
if func_name == "<lambda>":
|
||||
raise ValueError("You must provide a name for lambda functions")
|
||||
|
||||
func_doc = description or fn.__doc__ or ""
|
||||
is_async = inspect.iscoroutinefunction(fn)
|
||||
|
||||
# Find context parameter if it exists
|
||||
if context_kwarg is None:
|
||||
sig = inspect.signature(fn)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is mcp.server.fastmcp.Context:
|
||||
context_kwarg = param_name
|
||||
break
|
||||
|
||||
func_arg_metadata = func_metadata(
|
||||
fn,
|
||||
skip_names=[context_kwarg] if context_kwarg is not None else [],
|
||||
)
|
||||
parameters = func_arg_metadata.arg_model.model_json_schema()
|
||||
|
||||
return cls(
|
||||
fn=fn,
|
||||
name=func_name,
|
||||
description=func_doc,
|
||||
parameters=parameters,
|
||||
fn_metadata=func_arg_metadata,
|
||||
is_async=is_async,
|
||||
context_kwarg=context_kwarg,
|
||||
)
|
||||
|
||||
async def run(self, arguments: dict, context: "Context | None" = None) -> Any:
|
||||
"""Run the tool with arguments."""
|
||||
try:
|
||||
return await self.fn_metadata.call_fn_with_arg_validation(
|
||||
self.fn,
|
||||
self.is_async,
|
||||
arguments,
|
||||
{self.context_kwarg: context}
|
||||
if self.context_kwarg is not None
|
||||
else None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||
54
src/mcp/server/fastmcp/tools/tool_manager.py
Normal file
54
src/mcp/server/fastmcp/tools/tool_manager.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from mcp.server.fastmcp.exceptions import ToolError
|
||||
from mcp.server.fastmcp.tools.base import Tool
|
||||
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from collections.abc import Callable
|
||||
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.fastmcp.server import Context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manages FastMCP tools."""
|
||||
|
||||
def __init__(self, warn_on_duplicate_tools: bool = True):
|
||||
self._tools: dict[str, Tool] = {}
|
||||
self.warn_on_duplicate_tools = warn_on_duplicate_tools
|
||||
|
||||
def get_tool(self, name: str) -> Tool | None:
|
||||
"""Get tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""List all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def add_tool(
|
||||
self,
|
||||
fn: Callable,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> Tool:
|
||||
"""Add a tool to the server."""
|
||||
tool = Tool.from_function(fn, name=name, description=description)
|
||||
existing = self._tools.get(tool.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_tools:
|
||||
logger.warning(f"Tool already exists: {tool.name}")
|
||||
return existing
|
||||
self._tools[tool.name] = tool
|
||||
return tool
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: dict, context: "Context | None" = None
|
||||
) -> Any:
|
||||
"""Call a tool by name with arguments."""
|
||||
tool = self.get_tool(name)
|
||||
if not tool:
|
||||
raise ToolError(f"Unknown tool: {name}")
|
||||
|
||||
return await tool.run(arguments, context=context)
|
||||
1
src/mcp/server/fastmcp/utilities/__init__.py
Normal file
1
src/mcp/server/fastmcp/utilities/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastMCP utility modules."""
|
||||
210
src/mcp/server/fastmcp/utilities/func_metadata.py
Normal file
210
src/mcp/server/fastmcp/utilities/func_metadata.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import inspect
|
||||
from collections.abc import Callable, Sequence, Awaitable
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
ForwardRef,
|
||||
)
|
||||
from pydantic import Field
|
||||
from mcp.server.fastmcp.exceptions import InvalidSignature
|
||||
from pydantic._internal._typing_extra import eval_type_backport
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic import ConfigDict, create_model
|
||||
from pydantic import WithJsonSchema
|
||||
from pydantic_core import PydanticUndefined
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ArgModelBase(BaseModel):
|
||||
"""A model representing the arguments to a function."""
|
||||
|
||||
def model_dump_one_level(self) -> dict[str, Any]:
|
||||
"""Return a dict of the model's fields, one level deep.
|
||||
|
||||
That is, sub-models etc are not dumped - they are kept as pydantic models.
|
||||
"""
|
||||
kwargs: dict[str, Any] = {}
|
||||
for field_name in self.model_fields.keys():
|
||||
kwargs[field_name] = getattr(self, field_name)
|
||||
return kwargs
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
class FuncMetadata(BaseModel):
|
||||
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
|
||||
# We can add things in the future like
|
||||
# - Maybe some args are excluded from attempting to parse from JSON
|
||||
# - Maybe some args are special (like context) for dependency injection
|
||||
|
||||
async def call_fn_with_arg_validation(
|
||||
self,
|
||||
fn: Callable[..., Any] | Awaitable[Any],
|
||||
fn_is_async: bool,
|
||||
arguments_to_validate: dict[str, Any],
|
||||
arguments_to_pass_directly: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Call the given function with arguments validated and injected.
|
||||
|
||||
Arguments are first attempted to be parsed from JSON, then validated against
|
||||
the argument model, before being passed to the function.
|
||||
"""
|
||||
arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
|
||||
arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
|
||||
arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()
|
||||
|
||||
arguments_parsed_dict |= arguments_to_pass_directly or {}
|
||||
|
||||
if fn_is_async:
|
||||
if isinstance(fn, Awaitable):
|
||||
return await fn
|
||||
return await fn(**arguments_parsed_dict)
|
||||
if isinstance(fn, Callable):
|
||||
return fn(**arguments_parsed_dict)
|
||||
raise TypeError("fn must be either Callable or Awaitable")
|
||||
|
||||
def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Pre-parse data from JSON.
|
||||
|
||||
Return a dict with same keys as input but with values parsed from JSON
|
||||
if appropriate.
|
||||
|
||||
This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
|
||||
a string rather than an actual list. Claude desktop is prone to this - in fact
|
||||
it seems incapable of NOT doing this. For sub-models, it tends to pass
|
||||
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
|
||||
"""
|
||||
new_data = data.copy() # Shallow copy
|
||||
for field_name, field_info in self.arg_model.model_fields.items():
|
||||
if field_name not in data.keys():
|
||||
continue
|
||||
if isinstance(data[field_name], str):
|
||||
try:
|
||||
pre_parsed = json.loads(data[field_name])
|
||||
except json.JSONDecodeError:
|
||||
continue # Not JSON - skip
|
||||
if isinstance(pre_parsed, str):
|
||||
# This is likely that the raw value is e.g. `"hello"` which we
|
||||
# Should really be parsed as '"hello"' in Python - but if we parse
|
||||
# it as JSON it'll turn into just 'hello'. So we skip it.
|
||||
continue
|
||||
new_data[field_name] = pre_parsed
|
||||
assert new_data.keys() == data.keys()
|
||||
return new_data
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
|
||||
def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
|
||||
"""Given a function, return metadata including a pydantic model representing its signature.
|
||||
|
||||
The use case for this is
|
||||
```
|
||||
meta = func_to_pyd(func)
|
||||
validated_args = meta.arg_model.model_validate(some_raw_data_dict)
|
||||
return func(**validated_args.model_dump_one_level())
|
||||
```
|
||||
|
||||
**critically** it also provides pre-parse helper to attempt to parse things from JSON.
|
||||
|
||||
Args:
|
||||
func: The function to convert to a pydantic model
|
||||
skip_names: A list of parameter names to skip. These will not be included in
|
||||
the model.
|
||||
Returns:
|
||||
A pydantic model representing the function's signature.
|
||||
"""
|
||||
sig = _get_typed_signature(func)
|
||||
params = sig.parameters
|
||||
dynamic_pydantic_model_params: dict[str, Any] = {}
|
||||
globalns = getattr(func, "__globals__", {})
|
||||
for param in params.values():
|
||||
if param.name.startswith("_"):
|
||||
raise InvalidSignature(
|
||||
f"Parameter {param.name} of {func.__name__} may not start with an underscore"
|
||||
)
|
||||
if param.name in skip_names:
|
||||
continue
|
||||
annotation = param.annotation
|
||||
|
||||
# `x: None` / `x: None = None`
|
||||
if annotation is None:
|
||||
annotation = Annotated[
|
||||
None,
|
||||
Field(
|
||||
default=param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined
|
||||
),
|
||||
]
|
||||
|
||||
# Untyped field
|
||||
if annotation is inspect.Parameter.empty:
|
||||
annotation = Annotated[
|
||||
Any,
|
||||
Field(),
|
||||
# 🤷
|
||||
WithJsonSchema({"title": param.name, "type": "string"}),
|
||||
]
|
||||
|
||||
field_info = FieldInfo.from_annotated_attribute(
|
||||
_get_typed_annotation(annotation, globalns),
|
||||
param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined,
|
||||
)
|
||||
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
|
||||
continue
|
||||
|
||||
arguments_model = create_model(
|
||||
f"{func.__name__}Arguments",
|
||||
**dynamic_pydantic_model_params,
|
||||
__base__=ArgModelBase,
|
||||
)
|
||||
resp = FuncMetadata(arg_model=arguments_model)
|
||||
return resp
|
||||
|
||||
|
||||
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
def try_eval_type(value, globalns, localns):
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
return value, False
|
||||
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation, status = try_eval_type(annotation, globalns, globalns)
|
||||
|
||||
# This check and raise could perhaps be skipped, and we (FastMCP) just call
|
||||
# model_rebuild right before using it 🤷
|
||||
if status is False:
|
||||
raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
"""Get function signature while evaluating forward references"""
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=_get_typed_annotation(param.annotation, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
typed_signature = inspect.Signature(typed_params)
|
||||
return typed_signature
|
||||
41
src/mcp/server/fastmcp/utilities/logging.py
Normal file
41
src/mcp/server/fastmcp/utilities/logging.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Logging utilities for FastMCP."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger nested under MCPnamespace.
|
||||
|
||||
Args:
|
||||
name: the name of the logger, which will be prefixed with 'FastMCP.'
|
||||
|
||||
Returns:
|
||||
a configured logger instance
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
|
||||
) -> None:
|
||||
"""Configure logging for MCP.
|
||||
|
||||
Args:
|
||||
level: the log level to use
|
||||
"""
|
||||
handlers = []
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True))
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not handlers:
|
||||
handlers.append(logging.StreamHandler())
|
||||
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(message)s",
|
||||
handlers=handlers,
|
||||
)
|
||||
54
src/mcp/server/fastmcp/utilities/types.py
Normal file
54
src/mcp/server/fastmcp/utilities/types.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Common types used across FastMCP."""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
from mcp.types import ImageContent
|
||||
|
||||
|
||||
class Image:
|
||||
"""Helper class for returning images from tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str | Path | None = None,
|
||||
data: bytes | None = None,
|
||||
format: str | None = None,
|
||||
):
|
||||
if path is None and data is None:
|
||||
raise ValueError("Either path or data must be provided")
|
||||
if path is not None and data is not None:
|
||||
raise ValueError("Only one of path or data can be provided")
|
||||
|
||||
self.path = Path(path) if path else None
|
||||
self.data = data
|
||||
self._format = format
|
||||
self._mime_type = self._get_mime_type()
|
||||
|
||||
def _get_mime_type(self) -> str:
|
||||
"""Get MIME type from format or guess from file extension."""
|
||||
if self._format:
|
||||
return f"image/{self._format.lower()}"
|
||||
|
||||
if self.path:
|
||||
suffix = self.path.suffix.lower()
|
||||
return {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
}.get(suffix, "application/octet-stream")
|
||||
return "image/png" # default for raw binary data
|
||||
|
||||
def to_image_content(self) -> ImageContent:
|
||||
"""Convert to MCP ImageContent."""
|
||||
if self.path:
|
||||
with open(self.path, "rb") as f:
|
||||
data = base64.b64encode(f.read()).decode()
|
||||
elif self.data is not None:
|
||||
data = base64.b64encode(self.data).decode()
|
||||
else:
|
||||
raise ValueError("No image data available")
|
||||
|
||||
return ImageContent(type="image", data=data, mimeType=self._mime_type)
|
||||
3
src/mcp/server/lowlevel/__init__.py
Normal file
3
src/mcp/server/lowlevel/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .server import Server, NotificationOptions
|
||||
|
||||
__all__ = ["Server", "NotificationOptions"]
|
||||
500
src/mcp/server/lowlevel/server.py
Normal file
500
src/mcp/server/lowlevel/server.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
MCP Server Module
|
||||
|
||||
This module provides a framework for creating an MCP (Model Context Protocol) server.
|
||||
It allows you to easily define and handle various types of requests and notifications
|
||||
in an asynchronous manner.
|
||||
|
||||
Usage:
|
||||
1. Create a Server instance:
|
||||
server = Server("your_server_name")
|
||||
|
||||
2. Define request handlers using decorators:
|
||||
@server.list_prompts()
|
||||
async def handle_list_prompts() -> list[types.Prompt]:
|
||||
# Implementation
|
||||
|
||||
@server.get_prompt()
|
||||
async def handle_get_prompt(
|
||||
name: str, arguments: dict[str, str] | None
|
||||
) -> types.GetPromptResult:
|
||||
# Implementation
|
||||
|
||||
@server.list_tools()
|
||||
async def handle_list_tools() -> list[types.Tool]:
|
||||
# Implementation
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict | None
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
# Implementation
|
||||
|
||||
@server.list_resource_templates()
|
||||
async def handle_list_resource_templates() -> list[types.ResourceTemplate]:
|
||||
# Implementation
|
||||
|
||||
3. Define notification handlers if needed:
|
||||
@server.progress_notification()
|
||||
async def handle_progress(
|
||||
progress_token: str | int, progress: float, total: float | None
|
||||
) -> None:
|
||||
# Implementation
|
||||
|
||||
4. Run the server:
|
||||
async def main():
|
||||
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="your_server_name",
|
||||
server_version="your_version",
|
||||
capabilities=server.get_capabilities(
|
||||
notification_options=NotificationOptions(),
|
||||
experimental_capabilities={},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
The Server class provides methods to register handlers for various MCP requests and
|
||||
notifications. It automatically manages the request context and handles incoming
|
||||
messages from the client.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Sequence
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server as stdio_server
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
|
||||
contextvars.ContextVar("request_ctx")
|
||||
)
|
||||
|
||||
|
||||
class NotificationOptions:
|
||||
def __init__(
|
||||
self,
|
||||
prompts_changed: bool = False,
|
||||
resources_changed: bool = False,
|
||||
tools_changed: bool = False,
|
||||
):
|
||||
self.prompts_changed = prompts_changed
|
||||
self.resources_changed = resources_changed
|
||||
self.tools_changed = tools_changed
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.request_handlers: dict[
|
||||
type, Callable[..., Awaitable[types.ServerResult]]
|
||||
] = {
|
||||
types.PingRequest: _ping_handler,
|
||||
}
|
||||
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
|
||||
self.notification_options = NotificationOptions()
|
||||
logger.debug(f"Initializing server '{name}'")
|
||||
|
||||
def create_initialization_options(
|
||||
self,
|
||||
notification_options: NotificationOptions | None = None,
|
||||
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
|
||||
) -> InitializationOptions:
|
||||
"""Create initialization options from this server instance."""
|
||||
|
||||
def pkg_version(package: str) -> str:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
v = version(package)
|
||||
if v is not None:
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "unknown"
|
||||
|
||||
return InitializationOptions(
|
||||
server_name=self.name,
|
||||
server_version=pkg_version("mcp"),
|
||||
capabilities=self.get_capabilities(
|
||||
notification_options or NotificationOptions(),
|
||||
experimental_capabilities or {},
|
||||
),
|
||||
)
|
||||
|
||||
def get_capabilities(
|
||||
self,
|
||||
notification_options: NotificationOptions,
|
||||
experimental_capabilities: dict[str, dict[str, Any]],
|
||||
) -> types.ServerCapabilities:
|
||||
"""Convert existing handlers to a ServerCapabilities object."""
|
||||
prompts_capability = None
|
||||
resources_capability = None
|
||||
tools_capability = None
|
||||
logging_capability = None
|
||||
|
||||
# Set prompt capabilities if handler exists
|
||||
if types.ListPromptsRequest in self.request_handlers:
|
||||
prompts_capability = types.PromptsCapability(
|
||||
listChanged=notification_options.prompts_changed
|
||||
)
|
||||
|
||||
# Set resource capabilities if handler exists
|
||||
if types.ListResourcesRequest in self.request_handlers:
|
||||
resources_capability = types.ResourcesCapability(
|
||||
subscribe=False, listChanged=notification_options.resources_changed
|
||||
)
|
||||
|
||||
# Set tool capabilities if handler exists
|
||||
if types.ListToolsRequest in self.request_handlers:
|
||||
tools_capability = types.ToolsCapability(
|
||||
listChanged=notification_options.tools_changed
|
||||
)
|
||||
|
||||
# Set logging capabilities if handler exists
|
||||
if types.SetLevelRequest in self.request_handlers:
|
||||
logging_capability = types.LoggingCapability()
|
||||
|
||||
return types.ServerCapabilities(
|
||||
prompts=prompts_capability,
|
||||
resources=resources_capability,
|
||||
tools=tools_capability,
|
||||
logging=logging_capability,
|
||||
experimental=experimental_capabilities,
|
||||
)
|
||||
|
||||
@property
|
||||
def request_context(self) -> RequestContext[ServerSession]:
|
||||
"""If called outside of a request context, this will raise a LookupError."""
|
||||
return request_ctx.get()
|
||||
|
||||
def list_prompts(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
|
||||
logger.debug("Registering handler for PromptListRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
prompts = await func()
|
||||
return types.ServerResult(types.ListPromptsResult(prompts=prompts))
|
||||
|
||||
self.request_handlers[types.ListPromptsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get_prompt(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[str, dict[str, str] | None], Awaitable[types.GetPromptResult]
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for GetPromptRequest")
|
||||
|
||||
async def handler(req: types.GetPromptRequest):
|
||||
prompt_get = await func(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(prompt_get)
|
||||
|
||||
self.request_handlers[types.GetPromptRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resources(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
|
||||
logger.debug("Registering handler for ListResourcesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
resources = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourcesResult(resources=resources)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourcesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_resource_templates(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
|
||||
logger.debug("Registering handler for ListResourceTemplatesRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
templates = await func()
|
||||
return types.ServerResult(
|
||||
types.ListResourceTemplatesResult(resourceTemplates=templates)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ListResourceTemplatesRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def read_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
|
||||
logger.debug("Registering handler for ReadResourceRequest")
|
||||
|
||||
async def handler(req: types.ReadResourceRequest):
|
||||
result = await func(req.params.uri)
|
||||
match result:
|
||||
case str(s):
|
||||
content = types.TextResourceContents(
|
||||
uri=req.params.uri,
|
||||
text=s,
|
||||
mimeType="text/plain",
|
||||
)
|
||||
case bytes(b):
|
||||
import base64
|
||||
|
||||
content = types.BlobResourceContents(
|
||||
uri=req.params.uri,
|
||||
blob=base64.urlsafe_b64encode(b).decode(),
|
||||
mimeType="application/octet-stream",
|
||||
)
|
||||
|
||||
return types.ServerResult(
|
||||
types.ReadResourceResult(
|
||||
contents=[content],
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.ReadResourceRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def set_logging_level(self):
|
||||
def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SetLevelRequest")
|
||||
|
||||
async def handler(req: types.SetLevelRequest):
|
||||
await func(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SetLevelRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def subscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for SubscribeRequest")
|
||||
|
||||
async def handler(req: types.SubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.SubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def unsubscribe_resource(self):
|
||||
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
|
||||
logger.debug("Registering handler for UnsubscribeRequest")
|
||||
|
||||
async def handler(req: types.UnsubscribeRequest):
|
||||
await func(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
self.request_handlers[types.UnsubscribeRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
|
||||
logger.debug("Registering handler for ListToolsRequest")
|
||||
|
||||
async def handler(_: Any):
|
||||
tools = await func()
|
||||
return types.ServerResult(types.ListToolsResult(tools=tools))
|
||||
|
||||
self.request_handlers[types.ListToolsRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
def decorator(
|
||||
func: Callable[
|
||||
...,
|
||||
Awaitable[
|
||||
Sequence[
|
||||
types.TextContent | types.ImageContent | types.EmbeddedResource
|
||||
]
|
||||
],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CallToolRequest")
|
||||
|
||||
async def handler(req: types.CallToolRequest):
|
||||
try:
|
||||
results = await func(req.params.name, (req.params.arguments or {}))
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(content=list(results), isError=False)
|
||||
)
|
||||
except Exception as e:
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CallToolRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def progress_notification(self):
|
||||
def decorator(
|
||||
func: Callable[[str | int, float, float | None], Awaitable[None]],
|
||||
):
|
||||
logger.debug("Registering handler for ProgressNotification")
|
||||
|
||||
async def handler(req: types.ProgressNotification):
|
||||
await func(
|
||||
req.params.progressToken, req.params.progress, req.params.total
|
||||
)
|
||||
|
||||
self.notification_handlers[types.ProgressNotification] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def completion(self):
|
||||
"""Provides completions for prompts and resource templates"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[
|
||||
[
|
||||
types.PromptReference | types.ResourceReference,
|
||||
types.CompletionArgument,
|
||||
],
|
||||
Awaitable[types.Completion | None],
|
||||
],
|
||||
):
|
||||
logger.debug("Registering handler for CompleteRequest")
|
||||
|
||||
async def handler(req: types.CompleteRequest):
|
||||
completion = await func(req.params.ref, req.params.argument)
|
||||
return types.ServerResult(
|
||||
types.CompleteResult(
|
||||
completion=completion
|
||||
if completion is not None
|
||||
else types.Completion(values=[], total=None, hasMore=None),
|
||||
)
|
||||
)
|
||||
|
||||
self.request_handlers[types.CompleteRequest] = handler
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def run(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
|
||||
initialization_options: InitializationOptions,
|
||||
# When False, exceptions are returned as messages to the client.
|
||||
# When True, exceptions are raised, which will cause the server to shut down
|
||||
# but also make tracing exceptions much easier during testing and when using
|
||||
# in-process servers.
|
||||
raise_exceptions: bool = False,
|
||||
):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
async with ServerSession(
|
||||
read_stream, write_stream, initialization_options
|
||||
) as session:
|
||||
async for message in session.incoming_messages:
|
||||
logger.debug(f"Received message: {message}")
|
||||
|
||||
match message:
|
||||
case RequestResponder(request=types.ClientRequest(root=req)):
|
||||
logger.info(
|
||||
f"Processing request of type {type(req).__name__}"
|
||||
)
|
||||
if type(req) in self.request_handlers:
|
||||
handler = self.request_handlers[type(req)]
|
||||
logger.debug(
|
||||
f"Dispatching request of type {type(req).__name__}"
|
||||
)
|
||||
|
||||
token = None
|
||||
try:
|
||||
# Set our global state that can be retrieved via
|
||||
# app.get_request_context()
|
||||
token = request_ctx.set(
|
||||
RequestContext(
|
||||
message.request_id,
|
||||
message.request_meta,
|
||||
session,
|
||||
)
|
||||
)
|
||||
response = await handler(req)
|
||||
except McpError as err:
|
||||
response = err.error
|
||||
except Exception as err:
|
||||
if raise_exceptions:
|
||||
raise err
|
||||
response = types.ErrorData(
|
||||
code=0, message=str(err), data=None
|
||||
)
|
||||
finally:
|
||||
# Reset the global state after we are done
|
||||
if token is not None:
|
||||
request_ctx.reset(token)
|
||||
|
||||
await message.respond(response)
|
||||
else:
|
||||
await message.respond(
|
||||
types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="Method not found",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Response sent")
|
||||
case types.ClientNotification(root=notify):
|
||||
if type(notify) in self.notification_handlers:
|
||||
assert type(notify) in self.notification_handlers
|
||||
|
||||
handler = self.notification_handlers[type(notify)]
|
||||
logger.debug(
|
||||
f"Dispatching notification of type "
|
||||
f"{type(notify).__name__}"
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(notify)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"Uncaught exception in notification handler: "
|
||||
f"{err}"
|
||||
)
|
||||
|
||||
for warning in w:
|
||||
logger.info(
|
||||
f"Warning: {warning.category.__name__}: {warning.message}"
|
||||
)
|
||||
|
||||
|
||||
async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
Reference in New Issue
Block a user