Remove helper types

The helper types in mcp.server.types got really confusioning during
implementation as they overlapped with mcp.types. I now believe it
is better if we stay more low level to the spec types.

To do this, we now only use mcp.types everywhere. We renamed mcp.server.types
to mcp.server.models and removed it to the absolute minimum.
This commit is contained in:
David Soria Parra
2024-11-11 20:05:51 +00:00
parent 837309c3c8
commit f5d82bd229
8 changed files with 40 additions and 98 deletions

View File

@@ -7,7 +7,7 @@ from typing import Any, Sequence
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp.server import types from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext from mcp.shared.context import RequestContext
@@ -15,6 +15,10 @@ from mcp.shared.session import RequestResponder
from mcp.types import ( from mcp.types import (
METHOD_NOT_FOUND, METHOD_NOT_FOUND,
CallToolRequest, CallToolRequest,
GetPromptResult,
GetPromptRequest,
GetPromptResult,
ImageContent,
ClientNotification, ClientNotification,
ClientRequest, ClientRequest,
CompleteRequest, CompleteRequest,
@@ -84,7 +88,7 @@ class Server:
self, self,
notification_options: NotificationOptions | None = None, notification_options: NotificationOptions | None = None,
experimental_capabilities: dict[str, dict[str, Any]] | None = None, experimental_capabilities: dict[str, dict[str, Any]] | None = None,
) -> types.InitializationOptions: ) -> InitializationOptions:
"""Create initialization options from this server instance.""" """Create initialization options from this server instance."""
def pkg_version(package: str) -> str: def pkg_version(package: str) -> str:
@@ -99,7 +103,7 @@ class Server:
return "unknown" return "unknown"
return types.InitializationOptions( return InitializationOptions(
server_name=self.name, server_name=self.name,
server_version=pkg_version("mcp"), server_version=pkg_version("mcp"),
capabilities=self.get_capabilities( capabilities=self.get_capabilities(
@@ -168,50 +172,16 @@ class Server:
return decorator return decorator
def get_prompt(self): def get_prompt(self):
from mcp.types import (
GetPromptRequest,
GetPromptResult,
ImageContent,
)
from mcp.types import (
Role as Role,
)
def decorator( def decorator(
func: Callable[ func: Callable[
[str, dict[str, str] | None], Awaitable[types.PromptResponse] [str, dict[str, str] | None], Awaitable[GetPromptResult]
], ],
): ):
logger.debug("Registering handler for GetPromptRequest") logger.debug("Registering handler for GetPromptRequest")
async def handler(req: GetPromptRequest): async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments) prompt_get = await func(req.params.name, req.params.arguments)
messages: list[PromptMessage] = [] return ServerResult(prompt_get)
for message in prompt_get.messages:
match message.content:
case str() as text_content:
content = TextContent(type="text", text=text_content)
case types.ImageContent() as img_content:
content = ImageContent(
type="image",
data=img_content.data,
mimeType=img_content.mime_type,
)
case types.EmbeddedResource() as resource:
content = EmbeddedResource(
type="resource", resource=resource.resource
)
case _:
raise ValueError(
f"Unexpected content type: {type(message.content)}"
)
prompt_message = PromptMessage(role=message.role, content=content)
messages.append(prompt_message)
return ServerResult(
GetPromptResult(description=prompt_get.desc, messages=messages)
)
self.request_handlers[GetPromptRequest] = handler self.request_handlers[GetPromptRequest] = handler
return func return func
@@ -338,7 +308,7 @@ class Server:
def decorator( def decorator(
func: Callable[ func: Callable[
..., ...,
Awaitable[Sequence[str | types.ImageContent | types.EmbeddedResource]], Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]],
], ],
): ):
logger.debug("Registering handler for CallToolRequest") logger.debug("Registering handler for CallToolRequest")
@@ -351,15 +321,15 @@ class Server:
match result: match result:
case str() as text: case str() as text:
content.append(TextContent(type="text", text=text)) content.append(TextContent(type="text", text=text))
case types.ImageContent() as img: case ImageContent() as img:
content.append( content.append(
ImageContent( ImageContent(
type="image", type="image",
data=img.data, data=img.data,
mimeType=img.mime_type, mimeType=img.mimeType,
) )
) )
case types.EmbeddedResource() as resource: case EmbeddedResource() as resource:
content.append( content.append(
EmbeddedResource( EmbeddedResource(
type="resource", resource=resource.resource type="resource", resource=resource.resource
@@ -427,7 +397,7 @@ class Server:
self, self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage], write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions, initialization_options: InitializationOptions,
# When True, exceptions are returned as messages to the client. # When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down # When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using # but also make tracing exceptions much easier during testing and when using

View File

@@ -6,7 +6,7 @@ import anyio
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
from mcp.server.types import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.types import ServerCapabilities from mcp.types import ServerCapabilities
if not sys.warnoptions: if not sys.warnoptions:

19
src/mcp/server/models.py Normal file
View File

@@ -0,0 +1,19 @@
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""
from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel
from mcp.types import (
ServerCapabilities,
)
class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities

View File

@@ -6,7 +6,7 @@ import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp.server.types import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.shared.session import ( from mcp.shared.session import (
BaseSession, BaseSession,
RequestResponder, RequestResponder,

View File

@@ -1,46 +0,0 @@
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""
from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel
from mcp.types import (
BlobResourceContents,
Role,
ServerCapabilities,
TextResourceContents,
)
@dataclass
class ImageContent:
type: Literal["image"]
data: str
mime_type: str
@dataclass
class EmbeddedResource:
resource: TextResourceContents | BlobResourceContents
@dataclass
class Message:
role: Role
content: str | ImageContent | EmbeddedResource
@dataclass
class PromptResponse:
messages: list[Message]
desc: str | None = None
class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities

View File

@@ -2,7 +2,7 @@ import pytest
from pydantic import AnyUrl from pydantic import AnyUrl
from mcp.server import Server from mcp.server import Server
from mcp.server.types import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.types import Resource, ServerCapabilities from mcp.types import Resource, ServerCapabilities
TEST_INITIALIZATION_OPTIONS = InitializationOptions( TEST_INITIALIZATION_OPTIONS = InitializationOptions(

View File

@@ -4,7 +4,7 @@ import pytest
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.server import NotificationOptions, Server from mcp.server import NotificationOptions, Server
from mcp.server.session import ServerSession from mcp.server.session import ServerSession
from mcp.server.types import InitializationOptions from mcp.server.models import InitializationOptions
from mcp.types import ( from mcp.types import (
ClientNotification, ClientNotification,
InitializedNotification, InitializedNotification,

7
uv.lock generated
View File

@@ -331,15 +331,14 @@ wheels = [
[[package]] [[package]]
name = "pyright" name = "pyright"
version = "1.1.388" version = "1.1.378"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nodeenv" }, { name = "nodeenv" },
{ name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/9c/83/e9867538a794638d2d20ac3ab3106a31aca1d9cfea530c9b2921809dae03/pyright-1.1.388.tar.gz", hash = "sha256:0166d19b716b77fd2d9055de29f71d844874dbc6b9d3472ccd22df91db3dfa34", size = 21939 } sdist = { url = "https://files.pythonhosted.org/packages/3d/f0/e8aa5555410d88f898bef04da2102b0a9bf144658c98d34872e91621ced2/pyright-1.1.378.tar.gz", hash = "sha256:78a043be2876d12d0af101d667e92c7734f3ebb9db71dccc2c220e7e7eb89ca2", size = 17486 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/03/57/7fb00363b7f267a398c5bdf4f55f3e64f7c2076b2e7d2901b3373d52b6ff/pyright-1.1.388-py3-none-any.whl", hash = "sha256:c7068e9f2c23539c6ac35fc9efac6c6c1b9aa5a0ce97a9a8a6cf0090d7cbf84c", size = 18579 }, { url = "https://files.pythonhosted.org/packages/38/c6/f0d4bc20c13b20cecfbf13c699477c825e45767f1dc5068137323f86e495/pyright-1.1.378-py3-none-any.whl", hash = "sha256:8853776138b01bc284da07ac481235be7cc89d3176b073d2dba73636cb95be79", size = 18222 },
] ]
[[package]] [[package]]