Inherit environment variables deemed safe by default

This commit is contained in:
Justin Spahr-Summers
2024-11-06 11:05:20 +00:00
parent 60e9c7a0d7
commit 5508697b13

View File

@@ -1,3 +1,4 @@
import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -10,6 +11,34 @@ from pydantic import BaseModel, Field
from mcp_python.types import JSONRPCMessage from mcp_python.types import JSONRPCMessage
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
["APPDATA", "HOMEDRIVE", "HOMEPATH", "LOCALAPPDATA", "PATH",
"PROCESSOR_ARCHITECTURE", "SYSTEMDRIVE", "SYSTEMROOT", "TEMP",
"USERNAME", "USERPROFILE"]
if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)
def get_default_environment() -> dict[str, str]:
"""Returns a default environment object including only environment variables deemed safe to inherit."""
env: dict[str, str] = {}
for key in DEFAULT_INHERITED_ENV_VARS:
value = os.environ.get(key)
if value is None:
continue
if value.startswith("()"):
# Skip functions, which are a security risk
continue
env[key] = value
return env
class StdioServerParameters(BaseModel): class StdioServerParameters(BaseModel):
command: str command: str
"""The executable to run to start the server.""" """The executable to run to start the server."""
@@ -17,11 +46,11 @@ class StdioServerParameters(BaseModel):
args: list[str] = Field(default_factory=list) args: list[str] = Field(default_factory=list)
"""Command line arguments to pass to the executable.""" """Command line arguments to pass to the executable."""
env: dict[str, str] = Field(default_factory=dict) env: dict[str, str] | None = None
""" """
The environment to use when spawning the process. The environment to use when spawning the process.
The environment is NOT inherited from the parent process by default. If not specified, the result of get_default_environment() will be used.
""" """
@@ -41,7 +70,9 @@ async def stdio_client(server: StdioServerParameters):
write_stream, write_stream_reader = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
process = await anyio.open_process( process = await anyio.open_process(
[server.command, *server.args], env=server.env, stderr=sys.stderr [server.command, *server.args],
env=server.env if server.env is not None else get_default_environment(),
stderr=sys.stderr
) )
async def stdout_reader(): async def stdout_reader():