Merge pull request #120 from SecretiveShell/fix-encoding-errors

add text encoding params to STDIO client
This commit is contained in:
David Soria Parra
2025-01-02 14:59:36 +00:00
committed by GitHub

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Literal
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
@@ -65,6 +66,21 @@ class StdioServerParameters(BaseModel):
If not specified, the result of get_default_environment() will be used. If not specified, the result of get_default_environment() will be used.
""" """
encoding: str = "utf-8"
"""
The text encoding used when sending/receiving messages to the server
defaults to utf-8
"""
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
"""
The text encoding error handler.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""
@asynccontextmanager @asynccontextmanager
async def stdio_client(server: StdioServerParameters): async def stdio_client(server: StdioServerParameters):
@@ -93,7 +109,11 @@ async def stdio_client(server: StdioServerParameters):
try: try:
async with read_stream_writer: async with read_stream_writer:
buffer = "" buffer = ""
async for chunk in TextReceiveStream(process.stdout): async for chunk in TextReceiveStream(
process.stdout,
encoding=server.encoding,
errors=server.encoding_error_handler,
):
lines = (buffer + chunk).split("\n") lines = (buffer + chunk).split("\n")
buffer = lines.pop() buffer = lines.pop()
@@ -115,7 +135,12 @@ async def stdio_client(server: StdioServerParameters):
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True) json = message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send((json + "\n").encode()) await process.stdin.send(
(json + "\n").encode(
encoding=server.encoding,
errors=server.encoding_error_handler,
)
)
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()