mirror of
https://github.com/aljazceru/mcp-python-sdk.git
synced 2025-12-19 06:54:18 +01:00
refactor: modernize type hints and improve async context handling
- Update type hints to use Python 3.10 syntax (dict, list, X | None) - Replace requests with httpx for HTTP client consistency - Improve async context management using AsyncExitStack - Simplify server cleanup method
This commit is contained in:
@@ -3,9 +3,10 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
@@ -30,7 +31,7 @@ class Configuration:
|
||||
load_dotenv()
|
||||
|
||||
@staticmethod
|
||||
def load_config(file_path: str) -> Dict[str, Any]:
|
||||
def load_config(file_path: str) -> dict[str, Any]:
|
||||
"""Load server configuration from JSON file.
|
||||
|
||||
Args:
|
||||
@@ -64,12 +65,13 @@ class Configuration:
|
||||
class Server:
|
||||
"""Manages MCP server connections and tool execution."""
|
||||
|
||||
def __init__(self, name: str, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, name: str, config: dict[str, Any]) -> None:
|
||||
self.name: str = name
|
||||
self.config: Dict[str, Any] = config
|
||||
self.stdio_context: Optional[Any] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.config: dict[str, Any] = config
|
||||
self.stdio_context: Any | None = None
|
||||
self.session: ClientSession | None = None
|
||||
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
||||
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the server connection."""
|
||||
@@ -89,17 +91,16 @@ class Server:
|
||||
else None,
|
||||
)
|
||||
try:
|
||||
self.stdio_context = stdio_client(server_params)
|
||||
read, write = await self.stdio_context.__aenter__()
|
||||
self.session = ClientSession(read, write)
|
||||
await self.session.__aenter__()
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
read, write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||||
await self.session.initialize()
|
||||
except Exception as e:
|
||||
logging.error(f"Error initializing server {self.name}: {e}")
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def list_tools(self) -> List[Any]:
|
||||
async def list_tools(self) -> list[Any]:
|
||||
"""List available tools from the server.
|
||||
|
||||
Returns:
|
||||
@@ -124,7 +125,7 @@ class Server:
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
arguments: dict[str, Any],
|
||||
retries: int = 2,
|
||||
delay: float = 1.0,
|
||||
) -> Any:
|
||||
@@ -170,29 +171,9 @@ class Server:
|
||||
"""Clean up server resources."""
|
||||
async with self._cleanup_lock:
|
||||
try:
|
||||
if self.session:
|
||||
try:
|
||||
await self.session.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Warning during session cleanup for {self.name}: {e}"
|
||||
)
|
||||
finally:
|
||||
self.session = None
|
||||
|
||||
if self.stdio_context:
|
||||
try:
|
||||
await self.stdio_context.__aexit__(None, None, None)
|
||||
except (RuntimeError, asyncio.CancelledError) as e:
|
||||
logging.info(
|
||||
f"Note: Normal shutdown message for {self.name}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Warning during stdio cleanup for {self.name}: {e}"
|
||||
)
|
||||
finally:
|
||||
self.stdio_context = None
|
||||
await self.exit_stack.aclose()
|
||||
self.session = None
|
||||
self.stdio_context = None
|
||||
except Exception as e:
|
||||
logging.error(f"Error during cleanup of server {self.name}: {e}")
|
||||
|
||||
@@ -201,11 +182,11 @@ class Tool:
|
||||
"""Represents a tool with its properties and formatting."""
|
||||
|
||||
def __init__(
|
||||
self, name: str, description: str, input_schema: Dict[str, Any]
|
||||
self, name: str, description: str, input_schema: dict[str, Any]
|
||||
) -> None:
|
||||
self.name: str = name
|
||||
self.description: str = description
|
||||
self.input_schema: Dict[str, Any] = input_schema
|
||||
self.input_schema: dict[str, Any] = input_schema
|
||||
|
||||
def format_for_llm(self) -> str:
|
||||
"""Format tool information for LLM.
|
||||
@@ -237,7 +218,7 @@ class LLMClient:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key: str = api_key
|
||||
|
||||
def get_response(self, messages: List[Dict[str, str]]) -> str:
|
||||
def get_response(self, messages: list[dict[str, str]]) -> str:
|
||||
"""Get a response from the LLM.
|
||||
|
||||
Args:
|
||||
@@ -247,7 +228,7 @@ class LLMClient:
|
||||
The LLM's response as a string.
|
||||
|
||||
Raises:
|
||||
RequestException: If the request to the LLM fails.
|
||||
httpx.RequestError: If the request to the LLM fails.
|
||||
"""
|
||||
url = "https://api.groq.com/openai/v1/chat/completions"
|
||||
|
||||
@@ -266,16 +247,17 @@ class LLMClient:
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
with httpx.Client() as client:
|
||||
response = client.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
error_message = f"Error getting LLM response: {str(e)}"
|
||||
logging.error(error_message)
|
||||
|
||||
if e.response is not None:
|
||||
if hasattr(e, 'response'):
|
||||
status_code = e.response.status_code
|
||||
logging.error(f"Status code: {status_code}")
|
||||
logging.error(f"Response details: {e.response.text}")
|
||||
@@ -289,8 +271,8 @@ class LLMClient:
|
||||
class ChatSession:
|
||||
"""Orchestrates the interaction between user, LLM, and tools."""
|
||||
|
||||
def __init__(self, servers: List[Server], llm_client: LLMClient) -> None:
|
||||
self.servers: List[Server] = servers
|
||||
def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
|
||||
self.servers: list[Server] = servers
|
||||
self.llm_client: LLMClient = llm_client
|
||||
|
||||
async def cleanup_servers(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user