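"""A simple MCP chatbot client that wires stdio MCP servers to an LLM.

Run sketch (assumptions: this file is saved as main.py, a servers_config.json
sits next to it, and LLM_API_KEY is available via .env or the environment):

    python main.py
"""
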
import asyncio
import json
import logging
import os
import shutil
from contextlib import AsyncExitStack
from typing import Any

import httpx
from dotenv import load_dotenv

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class Configuration:
    """Manages configuration and environment variables for the MCP client."""

    def __init__(self) -> None:
        """Initialize configuration with environment variables."""
        self.load_env()
        self.api_key = os.getenv("LLM_API_KEY")

    @staticmethod
    def load_env() -> None:
        """Load environment variables from .env file."""
        load_dotenv()

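    # A minimal .env sketch for the key read above (the value is a placeholder,
    # not a real credential):
    #
    #     LLM_API_KEY=your-api-key-here
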
    @staticmethod
    def load_config(file_path: str) -> dict[str, Any]:
        """Load server configuration from JSON file.

        Args:
            file_path: Path to the JSON configuration file.

        Returns:
            Dict containing server configuration.

        Raises:
            FileNotFoundError: If configuration file doesn't exist.
            JSONDecodeError: If configuration file is invalid JSON.
        """
        with open(file_path, "r") as f:
            return json.load(f)

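    # A sketch of the JSON layout this method is expected to return, inferred
    # from how main() and Server.initialize() read it; the server name, package,
    # and env values below are illustrative, not part of this module:
    #
    #     {
    #       "mcpServers": {
    #         "example-server": {
    #           "command": "npx",
    #           "args": ["-y", "example-mcp-server"],
    #           "env": {"EXAMPLE_VAR": "value"}
    #         }
    #       }
    #     }
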
    @property
    def llm_api_key(self) -> str:
        """Get the LLM API key.

        Returns:
            The API key as a string.

        Raises:
            ValueError: If the API key is not found in environment variables.
        """
        if not self.api_key:
            raise ValueError("LLM_API_KEY not found in environment variables")
        return self.api_key


class Server:
    """Manages MCP server connections and tool execution."""

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        self.name: str = name
        self.config: dict[str, Any] = config
        self.stdio_context: Any | None = None
        self.session: ClientSession | None = None
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.exit_stack: AsyncExitStack = AsyncExitStack()

    async def initialize(self) -> None:
        """Initialize the server connection."""
        command = (
            shutil.which("npx")
            if self.config["command"] == "npx"
            else self.config["command"]
        )
        if command is None:
            raise ValueError("The command must be a valid string and cannot be None.")

        server_params = StdioServerParameters(
            command=command,
            args=self.config["args"],
            env={**os.environ, **self.config["env"]}
            if self.config.get("env")
            else None,
        )
        try:
            stdio_transport = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            read, write = stdio_transport
            session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await session.initialize()
            self.session = session
        except Exception as e:
            logging.error(f"Error initializing server {self.name}: {e}")
            await self.cleanup()
            raise

    async def list_tools(self) -> list[Any]:
        """List available tools from the server.

        Returns:
            A list of available tools.

        Raises:
            RuntimeError: If the server is not initialized.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        tools_response = await self.session.list_tools()
        tools = []

        # The list_tools result iterates as (field_name, value) pairs; collect
        # the entries under the "tools" field.
        for item in tools_response:
            if isinstance(item, tuple) and item[0] == "tools":
                tools.extend(
                    Tool(tool.name, tool.description, tool.inputSchema)
                    for tool in item[1]
                )

        return tools

    async def execute_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Execute a tool with retry mechanism.

        Args:
            tool_name: Name of the tool to execute.
            arguments: Tool arguments.
            retries: Number of retry attempts.
            delay: Delay between retries in seconds.

        Returns:
            Tool execution result.

        Raises:
            RuntimeError: If server is not initialized.
            Exception: If tool execution fails after all retries.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        attempt = 0
        while attempt < retries:
            try:
                logging.info(f"Executing {tool_name}...")
                result = await self.session.call_tool(tool_name, arguments)

                return result

            except Exception as e:
                attempt += 1
                logging.warning(
                    f"Error executing tool: {e}. Attempt {attempt} of {retries}."
                )
                if attempt < retries:
                    logging.info(f"Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                else:
                    logging.error("Max retries reached. Failing.")
                    raise

    async def cleanup(self) -> None:
        """Clean up server resources."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
                self.session = None
                self.stdio_context = None
            except Exception as e:
                logging.error(f"Error during cleanup of server {self.name}: {e}")


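# A minimal usage sketch for Server, mirroring how main() and ChatSession drive
# it; the server name, config values, and tool name are illustrative, and the
# calls must run inside an async function:
#
#     server = Server("example", {"command": "npx", "args": ["-y", "example-mcp-server"]})
#     await server.initialize()
#     tools = await server.list_tools()
#     result = await server.execute_tool("tool-name", {"argument-name": "value"})
#     await server.cleanup()
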
class Tool:
    """Represents a tool with its properties and formatting."""

    def __init__(
        self, name: str, description: str, input_schema: dict[str, Any]
    ) -> None:
        self.name: str = name
        self.description: str = description
        self.input_schema: dict[str, Any] = input_schema

    def format_for_llm(self) -> str:
        """Format tool information for LLM.

        Returns:
            A formatted string describing the tool.
        """
        args_desc = []
        if "properties" in self.input_schema:
            for param_name, param_info in self.input_schema["properties"].items():
                arg_desc = (
                    f"- {param_name}: {param_info.get('description', 'No description')}"
                )
                if param_name in self.input_schema.get("required", []):
                    arg_desc += " (required)"
                args_desc.append(arg_desc)

        return f"""
Tool: {self.name}
Description: {self.description}
Arguments:
{chr(10).join(args_desc)}
"""


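# For a hypothetical tool named "get_weather" with one required "city"
# parameter, format_for_llm() above would produce text along these lines:
#
#     Tool: get_weather
#     Description: Fetch the current weather for a city
#     Arguments:
#     - city: Name of the city (required)
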
class LLMClient:
    """Manages communication with the LLM provider."""

    def __init__(self, api_key: str) -> None:
        self.api_key: str = api_key

    def get_response(self, messages: list[dict[str, str]]) -> str:
        """Get a response from the LLM.

        Args:
            messages: A list of message dictionaries.

        Returns:
            The LLM's response as a string, or an error message if the
            request fails.
        """
        url = "https://api.groq.com/openai/v1/chat/completions"

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "messages": messages,
            "model": "llama-3.2-90b-vision-preview",
            "temperature": 0.7,
            "max_tokens": 4096,
            "top_p": 1,
            "stream": False,
            "stop": None,
        }

        try:
            with httpx.Client() as client:
                response = client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
                return data["choices"][0]["message"]["content"]

        except httpx.HTTPError as e:
            # httpx.HTTPError covers both RequestError and HTTPStatusError, so
            # non-2xx responses raised by raise_for_status() are handled here too.
            error_message = f"Error getting LLM response: {str(e)}"
            logging.error(error_message)

            if isinstance(e, httpx.HTTPStatusError):
                status_code = e.response.status_code
                logging.error(f"Status code: {status_code}")
                logging.error(f"Response details: {e.response.text}")

            return (
                f"I encountered an error: {error_message}. "
                "Please try again or rephrase your request."
            )


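# A minimal usage sketch for LLMClient; the key and messages are illustrative
# and follow the OpenAI-style chat format that ChatSession builds below:
#
#     client = LLMClient("your-api-key")
#     reply = client.get_response(
#         [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Hello!"},
#         ]
#     )
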
class ChatSession:
    """Orchestrates the interaction between user, LLM, and tools."""

    def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
        self.servers: list[Server] = servers
        self.llm_client: LLMClient = llm_client

    async def cleanup_servers(self) -> None:
        """Clean up all servers properly."""
        cleanup_tasks = [
            asyncio.create_task(server.cleanup()) for server in self.servers
        ]
        if cleanup_tasks:
            try:
                await asyncio.gather(*cleanup_tasks, return_exceptions=True)
            except Exception as e:
                logging.warning(f"Warning during final cleanup: {e}")

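    # process_llm_response() below treats the LLM reply as a tool call only when
    # it parses as JSON of the shape requested in the system prompt; the names
    # here are placeholders:
    #
    #     {"tool": "tool-name", "arguments": {"argument-name": "value"}}
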
    async def process_llm_response(self, llm_response: str) -> str:
        """Process the LLM response and execute tools if needed.

        Args:
            llm_response: The response from the LLM.

        Returns:
            The result of tool execution or the original response.
        """
        try:
            tool_call = json.loads(llm_response)
            if "tool" in tool_call and "arguments" in tool_call:
                logging.info(f"Executing tool: {tool_call['tool']}")
                logging.info(f"With arguments: {tool_call['arguments']}")

                for server in self.servers:
                    tools = await server.list_tools()
                    if any(tool.name == tool_call["tool"] for tool in tools):
                        try:
                            result = await server.execute_tool(
                                tool_call["tool"], tool_call["arguments"]
                            )

                            if isinstance(result, dict) and "progress" in result:
                                progress = result["progress"]
                                total = result["total"]
                                percentage = (progress / total) * 100
                                logging.info(
                                    f"Progress: {progress}/{total} ({percentage:.1f}%)"
                                )

                            return f"Tool execution result: {result}"
                        except Exception as e:
                            error_msg = f"Error executing tool: {str(e)}"
                            logging.error(error_msg)
                            return error_msg

                return f"No server found with tool: {tool_call['tool']}"
            return llm_response
        except json.JSONDecodeError:
            return llm_response

    async def start(self) -> None:
        """Main chat session handler."""
        try:
            for server in self.servers:
                try:
                    await server.initialize()
                except Exception as e:
                    logging.error(f"Failed to initialize server: {e}")
                    await self.cleanup_servers()
                    return

            all_tools = []
            for server in self.servers:
                tools = await server.list_tools()
                all_tools.extend(tools)

            tools_description = "\n".join([tool.format_for_llm() for tool in all_tools])

            system_message = (
                "You are a helpful assistant with access to these tools:\n\n"
                f"{tools_description}\n"
                "Choose the appropriate tool based on the user's question. "
                "If no tool is needed, reply directly.\n\n"
                "IMPORTANT: When you need to use a tool, you must ONLY respond with "
                "the exact JSON object format below, nothing else:\n"
                "{\n"
                '    "tool": "tool-name",\n'
                '    "arguments": {\n'
                '        "argument-name": "value"\n'
                "    }\n"
                "}\n\n"
                "After receiving a tool's response:\n"
                "1. Transform the raw data into a natural, conversational response\n"
                "2. Keep responses concise but informative\n"
                "3. Focus on the most relevant information\n"
                "4. Use appropriate context from the user's question\n"
                "5. Avoid simply repeating the raw data\n\n"
                "Please use only the tools that are explicitly defined above."
            )

            messages = [{"role": "system", "content": system_message}]

            while True:
                try:
                    user_input = input("You: ").strip().lower()
                    if user_input in ["quit", "exit"]:
                        logging.info("\nExiting...")
                        break

                    messages.append({"role": "user", "content": user_input})

                    llm_response = self.llm_client.get_response(messages)
                    logging.info("\nAssistant: %s", llm_response)

                    result = await self.process_llm_response(llm_response)

                    if result != llm_response:
                        messages.append({"role": "assistant", "content": llm_response})
                        messages.append({"role": "system", "content": result})

                        final_response = self.llm_client.get_response(messages)
                        logging.info("\nFinal response: %s", final_response)
                        messages.append(
                            {"role": "assistant", "content": final_response}
                        )
                    else:
                        messages.append({"role": "assistant", "content": llm_response})

                except KeyboardInterrupt:
                    logging.info("\nExiting...")
                    break

        finally:
            await self.cleanup_servers()


async def main() -> None:
    """Initialize and run the chat session."""
    config = Configuration()
    server_config = config.load_config("servers_config.json")
    servers = [
        Server(name, srv_config)
        for name, srv_config in server_config["mcpServers"].items()
    ]
    llm_client = LLMClient(config.llm_api_key)
    chat_session = ChatSession(servers, llm_client)
    await chat_session.start()


if __name__ == "__main__":
    asyncio.run(main())