mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-23 16:54:24 +01:00
Use 120 characters instead of 88 (#856)
This commit is contained in:
committed by
GitHub
parent
f7265f7b91
commit
543961968c
@@ -42,13 +42,9 @@ class AssistantMessage(Message):
|
||||
super().__init__(content=content, **kwargs)
|
||||
|
||||
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](
|
||||
UserMessage | AssistantMessage
|
||||
)
|
||||
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)
|
||||
|
||||
SyncPromptResult = (
|
||||
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
)
|
||||
SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
|
||||
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
|
||||
|
||||
|
||||
@@ -56,24 +52,16 @@ class PromptArgument(BaseModel):
|
||||
"""An argument that can be passed to a prompt."""
|
||||
|
||||
name: str = Field(description="Name of the argument")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the argument does"
|
||||
)
|
||||
required: bool = Field(
|
||||
default=False, description="Whether the argument is required"
|
||||
)
|
||||
description: str | None = Field(None, description="Description of what the argument does")
|
||||
required: bool = Field(default=False, description="Whether the argument is required")
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
"""A prompt template that can be rendered with parameters."""
|
||||
|
||||
name: str = Field(description="Name of the prompt")
|
||||
description: str | None = Field(
|
||||
None, description="Description of what the prompt does"
|
||||
)
|
||||
arguments: list[PromptArgument] | None = Field(
|
||||
None, description="Arguments that can be passed to the prompt"
|
||||
)
|
||||
description: str | None = Field(None, description="Description of what the prompt does")
|
||||
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
|
||||
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
|
||||
|
||||
@classmethod
|
||||
@@ -154,14 +142,10 @@ class Prompt(BaseModel):
|
||||
content = TextContent(type="text", text=msg)
|
||||
messages.append(UserMessage(content=content))
|
||||
else:
|
||||
content = pydantic_core.to_json(
|
||||
msg, fallback=str, indent=2
|
||||
).decode()
|
||||
content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
|
||||
messages.append(Message(role="user", content=content))
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Could not convert prompt result to message: {msg}"
|
||||
)
|
||||
raise ValueError(f"Could not convert prompt result to message: {msg}")
|
||||
|
||||
return messages
|
||||
except Exception as e:
|
||||
|
||||
@@ -39,9 +39,7 @@ class PromptManager:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
async def render_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> list[Message]:
|
||||
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
|
||||
"""Render a prompt by name with arguments."""
|
||||
prompt = self.get_prompt(name)
|
||||
if not prompt:
|
||||
|
||||
@@ -19,13 +19,9 @@ class Resource(BaseModel, abc.ABC):
|
||||
|
||||
model_config = ConfigDict(validate_default=True)
|
||||
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
|
||||
default=..., description="URI of the resource"
|
||||
)
|
||||
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
|
||||
name: str | None = Field(description="Name of the resource", default=None)
|
||||
description: str | None = Field(
|
||||
description="Description of the resource", default=None
|
||||
)
|
||||
description: str | None = Field(description="Description of the resource", default=None)
|
||||
mime_type: str = Field(
|
||||
default="text/plain",
|
||||
description="MIME type of the resource content",
|
||||
|
||||
@@ -15,18 +15,12 @@ from mcp.server.fastmcp.resources.types import FunctionResource, Resource
|
||||
class ResourceTemplate(BaseModel):
|
||||
"""A template for dynamically creating resources."""
|
||||
|
||||
uri_template: str = Field(
|
||||
description="URI template with parameters (e.g. weather://{city}/current)"
|
||||
)
|
||||
uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
|
||||
name: str = Field(description="Name of the resource")
|
||||
description: str | None = Field(description="Description of what the resource does")
|
||||
mime_type: str = Field(
|
||||
default="text/plain", description="MIME type of the resource content"
|
||||
)
|
||||
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
|
||||
fn: Callable[..., Any] = Field(exclude=True)
|
||||
parameters: dict[str, Any] = Field(
|
||||
description="JSON schema for function parameters"
|
||||
)
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
|
||||
@@ -54,9 +54,7 @@ class FunctionResource(Resource):
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the resource by calling the wrapped function."""
|
||||
try:
|
||||
result = (
|
||||
await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
|
||||
)
|
||||
result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn()
|
||||
if isinstance(result, Resource):
|
||||
return await result.read()
|
||||
elif isinstance(result, bytes):
|
||||
@@ -141,9 +139,7 @@ class HttpResource(Resource):
|
||||
"""A resource that reads from an HTTP endpoint."""
|
||||
|
||||
url: str = Field(description="URL to fetch content from")
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
async def read(self) -> str | bytes:
|
||||
"""Read the HTTP content."""
|
||||
@@ -157,15 +153,9 @@ class DirectoryResource(Resource):
|
||||
"""A resource that lists files in a directory."""
|
||||
|
||||
path: Path = Field(description="Path to the directory")
|
||||
recursive: bool = Field(
|
||||
default=False, description="Whether to list files recursively"
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
default=None, description="Optional glob pattern to filter files"
|
||||
)
|
||||
mime_type: str = Field(
|
||||
default="application/json", description="MIME type of the resource content"
|
||||
)
|
||||
recursive: bool = Field(default=False, description="Whether to list files recursively")
|
||||
pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
|
||||
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
|
||||
|
||||
@pydantic.field_validator("path")
|
||||
@classmethod
|
||||
@@ -184,16 +174,8 @@ class DirectoryResource(Resource):
|
||||
|
||||
try:
|
||||
if self.pattern:
|
||||
return (
|
||||
list(self.path.glob(self.pattern))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob(self.pattern))
|
||||
)
|
||||
return (
|
||||
list(self.path.glob("*"))
|
||||
if not self.recursive
|
||||
else list(self.path.rglob("*"))
|
||||
)
|
||||
return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern))
|
||||
return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*"))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error listing directory {self.path}: {e}")
|
||||
|
||||
|
||||
@@ -97,9 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
|
||||
# StreamableHTTP settings
|
||||
json_response: bool = False
|
||||
stateless_http: bool = (
|
||||
False # If True, uses true stateless mode (new transport per request)
|
||||
)
|
||||
stateless_http: bool = False # If True, uses true stateless mode (new transport per request)
|
||||
|
||||
# resource settings
|
||||
warn_on_duplicate_resources: bool = True
|
||||
@@ -115,9 +113,9 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
description="List of dependencies to install in the server environment",
|
||||
)
|
||||
|
||||
lifespan: (
|
||||
Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None
|
||||
) = Field(None, description="Lifespan context manager")
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
|
||||
None, description="Lifespan context manager"
|
||||
)
|
||||
|
||||
auth: AuthSettings | None = None
|
||||
|
||||
@@ -125,9 +123,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
|
||||
def lifespan_wrapper(
|
||||
app: FastMCP,
|
||||
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
|
||||
) -> Callable[
|
||||
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
|
||||
]:
|
||||
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
|
||||
@asynccontextmanager
|
||||
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
|
||||
async with lifespan(app) as context:
|
||||
@@ -141,8 +137,7 @@ class FastMCP:
|
||||
self,
|
||||
name: str | None = None,
|
||||
instructions: str | None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
| None = None,
|
||||
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None,
|
||||
event_store: EventStore | None = None,
|
||||
*,
|
||||
tools: list[Tool] | None = None,
|
||||
@@ -153,31 +148,18 @@ class FastMCP:
|
||||
self._mcp_server = MCPServer(
|
||||
name=name or "FastMCP",
|
||||
instructions=instructions,
|
||||
lifespan=(
|
||||
lifespan_wrapper(self, self.settings.lifespan)
|
||||
if self.settings.lifespan
|
||||
else default_lifespan
|
||||
),
|
||||
)
|
||||
self._tool_manager = ToolManager(
|
||||
tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
|
||||
)
|
||||
self._resource_manager = ResourceManager(
|
||||
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources
|
||||
)
|
||||
self._prompt_manager = PromptManager(
|
||||
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts
|
||||
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
|
||||
)
|
||||
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
|
||||
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
|
||||
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
|
||||
if (self.settings.auth is not None) != (auth_server_provider is not None):
|
||||
# TODO: after we support separate authorization servers (see
|
||||
# https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284)
|
||||
# we should validate that if auth is enabled, we have either an
|
||||
# auth_server_provider to host our own authorization server,
|
||||
# OR the URL of a 3rd party authorization server.
|
||||
raise ValueError(
|
||||
"settings.auth must be specified if and only if auth_server_provider "
|
||||
"is specified"
|
||||
)
|
||||
raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified")
|
||||
self._auth_server_provider = auth_server_provider
|
||||
self._event_store = event_store
|
||||
self._custom_starlette_routes: list[Route] = []
|
||||
@@ -340,9 +322,7 @@ class FastMCP:
|
||||
description: Optional description of what the tool does
|
||||
annotations: Optional ToolAnnotations providing additional tool information
|
||||
"""
|
||||
self._tool_manager.add_tool(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
self._tool_manager.add_tool(fn, name=name, description=description, annotations=annotations)
|
||||
|
||||
def tool(
|
||||
self,
|
||||
@@ -379,14 +359,11 @@ class FastMCP:
|
||||
# Check if user passed function directly instead of calling decorator
|
||||
if callable(name):
|
||||
raise TypeError(
|
||||
"The @tool decorator was used incorrectly. "
|
||||
"Did you forget to call it? Use @tool() instead of @tool"
|
||||
"The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool"
|
||||
)
|
||||
|
||||
def decorator(fn: AnyFunction) -> AnyFunction:
|
||||
self.add_tool(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
self.add_tool(fn, name=name, description=description, annotations=annotations)
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
@@ -462,8 +439,7 @@ class FastMCP:
|
||||
|
||||
if uri_params != func_params:
|
||||
raise ValueError(
|
||||
f"Mismatch between URI parameters {uri_params} "
|
||||
f"and function parameters {func_params}"
|
||||
f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}"
|
||||
)
|
||||
|
||||
# Register as template
|
||||
@@ -496,9 +472,7 @@ class FastMCP:
|
||||
"""
|
||||
self._prompt_manager.add_prompt(prompt)
|
||||
|
||||
def prompt(
|
||||
self, name: str | None = None, description: str | None = None
|
||||
) -> Callable[[AnyFunction], AnyFunction]:
|
||||
def prompt(self, name: str | None = None, description: str | None = None) -> Callable[[AnyFunction], AnyFunction]:
|
||||
"""Decorator to register a prompt.
|
||||
|
||||
Args:
|
||||
@@ -665,9 +639,7 @@ class FastMCP:
|
||||
self.settings.mount_path = mount_path
|
||||
|
||||
# Create normalized endpoint considering the mount path
|
||||
normalized_message_endpoint = self._normalize_path(
|
||||
self.settings.mount_path, self.settings.message_path
|
||||
)
|
||||
normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path)
|
||||
|
||||
# Set up auth context and dependencies
|
||||
|
||||
@@ -764,9 +736,7 @@ class FastMCP:
|
||||
routes.extend(self._custom_starlette_routes)
|
||||
|
||||
# Create Starlette app with routes and middleware
|
||||
return Starlette(
|
||||
debug=self.settings.debug, routes=routes, middleware=middleware
|
||||
)
|
||||
return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)
|
||||
|
||||
def streamable_http_app(self) -> Starlette:
|
||||
"""Return an instance of the StreamableHTTP server app."""
|
||||
@@ -783,9 +753,7 @@ class FastMCP:
|
||||
)
|
||||
|
||||
# Create the ASGI handler
|
||||
async def handle_streamable_http(
|
||||
scope: Scope, receive: Receive, send: Send
|
||||
) -> None:
|
||||
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await self.session_manager.handle_request(scope, receive, send)
|
||||
|
||||
# Create routes
|
||||
@@ -861,9 +829,7 @@ class FastMCP:
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
async def get_prompt(
|
||||
self, name: str, arguments: dict[str, Any] | None = None
|
||||
) -> GetPromptResult:
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
|
||||
"""Get a prompt by name with arguments."""
|
||||
try:
|
||||
messages = await self._prompt_manager.render_prompt(name, arguments)
|
||||
@@ -936,9 +902,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
request_context: (
|
||||
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
|
||||
) = None,
|
||||
request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None,
|
||||
fastmcp: FastMCP | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
@@ -962,9 +926,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
raise ValueError("Context is not available outside of a request")
|
||||
return self._request_context
|
||||
|
||||
async def report_progress(
|
||||
self, progress: float, total: float | None = None, message: str | None = None
|
||||
) -> None:
|
||||
async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
|
||||
"""Report progress for the current operation.
|
||||
|
||||
Args:
|
||||
@@ -972,11 +934,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
total: Optional total value e.g. 100
|
||||
message: Optional message e.g. Starting render...
|
||||
"""
|
||||
progress_token = (
|
||||
self.request_context.meta.progressToken
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
progress_token = self.request_context.meta.progressToken if self.request_context.meta else None
|
||||
|
||||
if progress_token is None:
|
||||
return
|
||||
@@ -997,9 +955,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
Returns:
|
||||
The resource content as either text or bytes
|
||||
"""
|
||||
assert (
|
||||
self._fastmcp is not None
|
||||
), "Context is not available outside of a request"
|
||||
assert self._fastmcp is not None, "Context is not available outside of a request"
|
||||
return await self._fastmcp.read_resource(uri)
|
||||
|
||||
async def log(
|
||||
@@ -1027,11 +983,7 @@ class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
"""Get the client ID if available."""
|
||||
return (
|
||||
getattr(self.request_context.meta, "client_id", None)
|
||||
if self.request_context.meta
|
||||
else None
|
||||
)
|
||||
return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
|
||||
@@ -25,16 +25,11 @@ class Tool(BaseModel):
|
||||
description: str = Field(description="Description of what the tool does")
|
||||
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
|
||||
fn_metadata: FuncMetadata = Field(
|
||||
description="Metadata about the function including a pydantic model for tool"
|
||||
" arguments"
|
||||
description="Metadata about the function including a pydantic model for tool" " arguments"
|
||||
)
|
||||
is_async: bool = Field(description="Whether the tool is async")
|
||||
context_kwarg: str | None = Field(
|
||||
None, description="Name of the kwarg that should receive context"
|
||||
)
|
||||
annotations: ToolAnnotations | None = Field(
|
||||
None, description="Optional annotations for the tool"
|
||||
)
|
||||
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
|
||||
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@@ -93,9 +88,7 @@ class Tool(BaseModel):
|
||||
self.fn,
|
||||
self.is_async,
|
||||
arguments,
|
||||
{self.context_kwarg: context}
|
||||
if self.context_kwarg is not None
|
||||
else None,
|
||||
{self.context_kwarg: context} if self.context_kwarg is not None else None,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error executing tool {self.name}: {e}") from e
|
||||
|
||||
@@ -50,9 +50,7 @@ class ToolManager:
|
||||
annotations: ToolAnnotations | None = None,
|
||||
) -> Tool:
|
||||
"""Add a tool to the server."""
|
||||
tool = Tool.from_function(
|
||||
fn, name=name, description=description, annotations=annotations
|
||||
)
|
||||
tool = Tool.from_function(fn, name=name, description=description, annotations=annotations)
|
||||
existing = self._tools.get(tool.name)
|
||||
if existing:
|
||||
if self.warn_on_duplicate_tools:
|
||||
|
||||
@@ -102,9 +102,7 @@ class FuncMetadata(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def func_metadata(
|
||||
func: Callable[..., Any], skip_names: Sequence[str] = ()
|
||||
) -> FuncMetadata:
|
||||
def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata:
|
||||
"""Given a function, return metadata including a pydantic model representing its
|
||||
signature.
|
||||
|
||||
@@ -131,9 +129,7 @@ def func_metadata(
|
||||
globalns = getattr(func, "__globals__", {})
|
||||
for param in params.values():
|
||||
if param.name.startswith("_"):
|
||||
raise InvalidSignature(
|
||||
f"Parameter {param.name} of {func.__name__} cannot start with '_'"
|
||||
)
|
||||
raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
|
||||
if param.name in skip_names:
|
||||
continue
|
||||
annotation = param.annotation
|
||||
@@ -142,11 +138,7 @@ def func_metadata(
|
||||
if annotation is None:
|
||||
annotation = Annotated[
|
||||
None,
|
||||
Field(
|
||||
default=param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined
|
||||
),
|
||||
Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
|
||||
]
|
||||
|
||||
# Untyped field
|
||||
@@ -160,9 +152,7 @@ def func_metadata(
|
||||
|
||||
field_info = FieldInfo.from_annotated_attribute(
|
||||
_get_typed_annotation(annotation, globalns),
|
||||
param.default
|
||||
if param.default is not inspect.Parameter.empty
|
||||
else PydanticUndefined,
|
||||
param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
|
||||
)
|
||||
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
|
||||
continue
|
||||
@@ -177,9 +167,7 @@ def func_metadata(
|
||||
|
||||
|
||||
def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
def try_eval_type(
|
||||
value: Any, globalns: dict[str, Any], localns: dict[str, Any]
|
||||
) -> tuple[Any, bool]:
|
||||
def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]:
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
|
||||
Reference in New Issue
Block a user