refactor: modernize type hints and improve async context handling

- Update type hints to use Python 3.10 syntax (dict, list, X | None)
- Replace requests with httpx for HTTP client consistency
- Improve async context management using AsyncExitStack
- Simplify server cleanup method

commit 466e1e8eb7 (parent a0216c3e50)
Author: 3choff
Date:   2024-12-18 16:35:49 +00:00
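
For reference, the typing modernization in this commit follows a simple before/after pattern; a minimal sketch (the function name load is an illustrative placeholder, not from this diff):

    from typing import Any

    # Before (Python 3.8 era; needs Dict, List, Optional from typing):
    #     def load(paths: List[str]) -> Optional[Dict[str, Any]]: ...

    # After: builtin generics (3.9+) and PEP 604 union syntax (3.10+).
    def load(paths: list[str]) -> dict[str, Any] | None:
        return {"paths": paths} if paths else None

Any still comes from typing; only the generic containers and the Optional/Union spellings change.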

@@ -3,9 +3,10 @@ import json
 import logging
 import os
 import shutil
-from typing import Any, Dict, List, Optional
+from contextlib import AsyncExitStack
+from typing import Any
 
-import requests
+import httpx
 from dotenv import load_dotenv
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
@@ -30,7 +31,7 @@ class Configuration:
         load_dotenv()
 
     @staticmethod
-    def load_config(file_path: str) -> Dict[str, Any]:
+    def load_config(file_path: str) -> dict[str, Any]:
         """Load server configuration from JSON file.
 
         Args:
@@ -64,12 +65,13 @@ class Configuration:
 class Server:
     """Manages MCP server connections and tool execution."""
 
-    def __init__(self, name: str, config: Dict[str, Any]) -> None:
+    def __init__(self, name: str, config: dict[str, Any]) -> None:
         self.name: str = name
-        self.config: Dict[str, Any] = config
-        self.stdio_context: Optional[Any] = None
-        self.session: Optional[ClientSession] = None
+        self.config: dict[str, Any] = config
+        self.stdio_context: Any | None = None
+        self.session: ClientSession | None = None
         self._cleanup_lock: asyncio.Lock = asyncio.Lock()
+        self.exit_stack: AsyncExitStack = AsyncExitStack()
 
     async def initialize(self) -> None:
         """Initialize the server connection."""
@@ -89,17 +91,16 @@ class Server:
             else None,
         )
         try:
-            self.stdio_context = stdio_client(server_params)
-            read, write = await self.stdio_context.__aenter__()
-            self.session = ClientSession(read, write)
-            await self.session.__aenter__()
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            read, write = stdio_transport
+            self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
             await self.session.initialize()
         except Exception as e:
             logging.error(f"Error initializing server {self.name}: {e}")
             await self.cleanup()
             raise
 
-    async def list_tools(self) -> List[Any]:
+    async def list_tools(self) -> list[Any]:
         """List available tools from the server.
 
         Returns:
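
For context on the pattern above: enter_async_context() runs the manager's __aenter__ immediately and registers its __aexit__ with the stack, so every entered context is unwound in one place instead of via hand-rolled __aenter__/__aexit__ calls. A minimal runnable sketch (the resource manager is hypothetical, not from this codebase):

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager

    @asynccontextmanager
    async def resource(name: str):
        print(f"open {name}")
        try:
            yield name
        finally:
            print(f"close {name}")

    async def main() -> None:
        stack = AsyncExitStack()
        # __aenter__ runs now; __aexit__ is deferred to the stack.
        await stack.enter_async_context(resource("stdio"))
        await stack.enter_async_context(resource("session"))
        # Unwinds in reverse order: session closes before stdio.
        await stack.aclose()

    asyncio.run(main())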
@@ -124,7 +125,7 @@
     async def execute_tool(
         self,
         tool_name: str,
-        arguments: Dict[str, Any],
+        arguments: dict[str, Any],
         retries: int = 2,
         delay: float = 1.0,
     ) -> Any:
@@ -170,29 +171,9 @@
         """Clean up server resources."""
         async with self._cleanup_lock:
             try:
-                if self.session:
-                    try:
-                        await self.session.__aexit__(None, None, None)
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during session cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.session = None
-
-                if self.stdio_context:
-                    try:
-                        await self.stdio_context.__aexit__(None, None, None)
-                    except (RuntimeError, asyncio.CancelledError) as e:
-                        logging.info(
-                            f"Note: Normal shutdown message for {self.name}: {e}"
-                        )
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during stdio cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.stdio_context = None
+                await self.exit_stack.aclose()
+                self.session = None
+                self.stdio_context = None
             except Exception as e:
                 logging.error(f"Error during cleanup of server {self.name}: {e}")
@@ -201,11 +182,11 @@ class Tool:
     """Represents a tool with its properties and formatting."""
 
     def __init__(
-        self, name: str, description: str, input_schema: Dict[str, Any]
+        self, name: str, description: str, input_schema: dict[str, Any]
     ) -> None:
         self.name: str = name
         self.description: str = description
-        self.input_schema: Dict[str, Any] = input_schema
+        self.input_schema: dict[str, Any] = input_schema
 
     def format_for_llm(self) -> str:
         """Format tool information for LLM.
@@ -237,7 +218,7 @@ class LLMClient:
     def __init__(self, api_key: str) -> None:
         self.api_key: str = api_key
 
-    def get_response(self, messages: List[Dict[str, str]]) -> str:
+    def get_response(self, messages: list[dict[str, str]]) -> str:
         """Get a response from the LLM.
 
         Args:
@@ -247,7 +228,7 @@
             The LLM's response as a string.
 
         Raises:
-            RequestException: If the request to the LLM fails.
+            httpx.RequestError: If the request to the LLM fails.
         """
 
         url = "https://api.groq.com/openai/v1/chat/completions"
@@ -266,16 +247,17 @@
         }
 
         try:
-            response = requests.post(url, headers=headers, json=payload)
-            response.raise_for_status()
-            data = response.json()
-            return data["choices"][0]["message"]["content"]
+            with httpx.Client() as client:
+                response = client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
+                return data["choices"][0]["message"]["content"]
-        except requests.exceptions.RequestException as e:
+        except httpx.RequestError as e:
             error_message = f"Error getting LLM response: {str(e)}"
             logging.error(error_message)
 
-            if e.response is not None:
+            if hasattr(e, 'response'):
                 status_code = e.response.status_code
                 logging.error(f"Status code: {status_code}")
                 logging.error(f"Response details: {e.response.text}")
@@ -289,8 +271,8 @@
 class ChatSession:
     """Orchestrates the interaction between user, LLM, and tools."""
 
-    def __init__(self, servers: List[Server], llm_client: LLMClient) -> None:
-        self.servers: List[Server] = servers
+    def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
+        self.servers: list[Server] = servers
         self.llm_client: LLMClient = llm_client
 
     async def cleanup_servers(self) -> None: