add text encoding params to STDIO client

This commit is contained in:
TerminalMan
2024-12-29 16:41:30 +00:00
parent e691c511ab
commit 4f36581a5c

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Literal
import anyio import anyio
import anyio.lowlevel import anyio.lowlevel
@@ -65,6 +66,20 @@ class StdioServerParameters(BaseModel):
If not specified, the result of get_default_environment() will be used. If not specified, the result of get_default_environment() will be used.
""" """
encoding: str = "utf-8"
"""
The text encoding used when sending/receiving messages to the server
defaults to utf-8
"""
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
"""
The text encoding error handler.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for explanations of possible values
"""
@asynccontextmanager @asynccontextmanager
async def stdio_client(server: StdioServerParameters): async def stdio_client(server: StdioServerParameters):
@@ -93,7 +108,7 @@ async def stdio_client(server: StdioServerParameters):
try: try:
async with read_stream_writer: async with read_stream_writer:
buffer = "" buffer = ""
async for chunk in TextReceiveStream(process.stdout): async for chunk in TextReceiveStream(process.stdout, encoding=server.encoding, errors=server.encoding_error_handler):
lines = (buffer + chunk).split("\n") lines = (buffer + chunk).split("\n")
buffer = lines.pop() buffer = lines.pop()
@@ -115,7 +130,7 @@ async def stdio_client(server: StdioServerParameters):
async with write_stream_reader: async with write_stream_reader:
async for message in write_stream_reader: async for message in write_stream_reader:
json = message.model_dump_json(by_alias=True, exclude_none=True) json = message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send((json + "\n").encode()) await process.stdin.send((json + "\n").encode(encoding=server.encoding, errors=server.encoding_error_handler))
except anyio.ClosedResourceError: except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint() await anyio.lowlevel.checkpoint()