mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-19 14:44:20 +01:00
155 lines
5.6 KiB
Python
155 lines
5.6 KiB
Python
import os
|
|
import shutil
|
|
import subprocess
|
|
from time import sleep
|
|
from typing import List, Tuple
|
|
|
|
import streamlit.web.bootstrap
|
|
from jina import Gateway
|
|
from jina.serve.runtimes.gateway.composite import CompositeGateway
|
|
from streamlit.file_util import get_streamlit_file_path
|
|
from streamlit.web.server import Server as StreamlitServer
|
|
|
|
|
|
# Directory that contains this module; the Streamlit app, its config file and
# nginx.conf are all located relative to it.
cur_dir = os.path.split(__file__)[0]
|
|
|
|
|
|
def cmd(command, std_output=False, wait=True):
    """Run an external command, optionally capturing its output.

    Args:
        command: The command to execute, either as a list of arguments or as
            a single string. A string is split on whitespace; no shell quoting
            is interpreted.
        std_output: If False (default), stdout and stderr are captured through
            pipes. If True, the child writes directly to this process's
            stdout/stderr.
        wait: If True (default), block until the command finishes and collect
            its output. If False, start the command and return immediately.

    Returns:
        A ``(stdout, stderr)`` tuple. The values are ``bytes`` when output is
        captured and ``wait`` is True. They are ``None`` when ``std_output`` is
        True or when ``wait`` is False.
    """
    if isinstance(command, str):
        command = command.split()

    if std_output:
        process = subprocess.Popen(command)
    else:
        process = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )

    if not wait:
        # Fire-and-forget: nothing has been read from the child, so there is
        # no output to hand back. Note that captured pipes are never drained
        # in this mode.
        return None, None

    output, error = process.communicate()
    return output, error
|
|
|
|
|
|
class PlaygroundGateway(Gateway):
    """Jina gateway that serves the Streamlit playground script ``app.py``.

    ``__init__`` installs the bundled Streamlit config. ``setup_server``
    creates the Streamlit server, ``run_server`` starts it, and ``shutdown``
    stops it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.streamlit_script = 'app.py'
        # Copy the app_config.toml shipped next to this module over Streamlit's
        # config.toml, so the playground runs with the bundled settings.
        streamlit_config_toml_src = os.path.join(cur_dir, 'app_config.toml')
        streamlit_config_toml_dest = get_streamlit_file_path("config.toml")
        # Create the destination directory if it doesn't exist yet.
        os.makedirs(os.path.dirname(streamlit_config_toml_dest), exist_ok=True)
        shutil.copyfile(streamlit_config_toml_src, streamlit_config_toml_dest)

    async def setup_server(self):
        """Apply Streamlit's bootstrap fixes and create the Streamlit server."""
        script_path = os.path.join(cur_dir, self.streamlit_script)
        # Streamlit's _fix_sys_path puts os.path.dirname(path) on sys.path.
        # Pass the full script path (the same one the server runs) so that the
        # playground's own folder becomes importable. With the bare file name,
        # '' would be inserted instead.
        streamlit.web.bootstrap._fix_sys_path(script_path)
        streamlit.web.bootstrap._fix_matplotlib_crash()
        streamlit.web.bootstrap._fix_tornado_crash()
        streamlit.web.bootstrap._fix_sys_argv(self.streamlit_script, ())
        streamlit.web.bootstrap._fix_pydeck_mapbox_api_warning()
        streamlit_cmd = f'streamlit run {self.streamlit_script}'

        self.streamlit_server = StreamlitServer(script_path, streamlit_cmd)

    async def run_server(self):
        """Start the Streamlit server and install Streamlit's start/signal hooks."""
        await self.streamlit_server.start()
        streamlit.web.bootstrap._on_server_start(self.streamlit_server)
        streamlit.web.bootstrap._set_up_signal_handler(self.streamlit_server)

    async def shutdown(self):
        """Stop the Streamlit server, if setup_server() got as far as creating one."""
        server = getattr(self, 'streamlit_server', None)
        if server is not None:
            server.stop()
|
|
|
|
|
|
class CustomGateway(CompositeGateway):
    """The CustomGateway assumes that the gateway has been started with http on port 8080.

    This is the port on which the nginx process listens. After nginx has been started,
    it will start the playground on port 8501 and the actual HTTP gateway will start on port 8082.

    Nginx (configured by the nginx.conf next to this module) routes the requests in the
    following way:
    - /playground -> playground on port 8501
    - / -> HTTP gateway on port 8082
    """

    def __init__(self, **kwargs):
        # nginx takes over port 8080, so the wrapped HTTP gateway is moved to 8082.
        http_idx = 0
        http_port = kwargs['runtime_args']['port'][http_idx]
        if http_port != 8080:
            raise ValueError(
                f'Please, let http port ({http_port}) be 8080 for nginx to work'
            )
        kwargs['runtime_args']['port'][http_idx] = 8082
        kwargs['cors'] = True
        super().__init__(**kwargs)

        # Remove potentially clashing arguments from kwargs before reusing them
        # for the playground gateway.
        kwargs.pop("port", None)
        kwargs.pop("protocol", None)

        # Note: order is important. _add_gateway inserts the playground at the
        # front of self.gateways.
        self._add_gateway(
            PlaygroundGateway,
            8501,
            **kwargs,
        )

        # Initialised before nginx starts, so shutdown() can always read it.
        self.nginx_was_shutdown = False
        self.setup_nginx()

    async def shutdown(self):
        """Shut down the composed gateways, then stop nginx (at most once)."""
        try:
            await super().shutdown()
        finally:
            # Stop nginx even if one of the wrapped gateways fails to shut
            # down, so nginx is not left running with nothing behind it.
            if not self.nginx_was_shutdown:
                self.shutdown_nginx()
                self.nginx_was_shutdown = True

    def setup_nginx(self):
        """Start nginx with the nginx.conf located next to this module."""
        command = ['nginx', '-c', os.path.join(cur_dir, 'nginx.conf')]
        output, error = self._run_nginx_command(command)
        self.logger.info('Nginx started')
        self.logger.info(f'nginx output: {output}')
        self.logger.info(f'nginx error: {error}')

    def shutdown_nginx(self):
        """Stop the running nginx via ``nginx -s stop``."""
        command = ['nginx', '-s', 'stop']
        output, error = self._run_nginx_command(command)
        self.logger.info('Nginx stopped')
        self.logger.info(f'nginx output: {output}')
        self.logger.info(f'nginx error: {error}')

    def _run_nginx_command(self, command: List[str]) -> Tuple[bytes, bytes]:
        """Run an nginx command, retrying it under sudo if it wrote to stderr.

        Always sleeps 10 seconds afterwards (presumably to give nginx time to
        start or stop; TODO confirm). Returns the ``(stdout, stderr)`` bytes of
        the last attempt. The caller's ``command`` list is not modified.
        """
        self.logger.info(f'Running command: {command}')
        output, error = cmd(command)
        if error != b'':
            # on CI we need to use sudo; using NOW_CI_RUN isn't good if running test locally
            self.logger.info(f'nginx error: {error}')
            command = ['sudo', *command]
            self.logger.info(f'So running command: {command}')
            output, error = cmd(command)
        sleep(10)
        return output, error

    def _add_gateway(self, gateway_cls, port, protocol='http', **kwargs):
        """Instantiate ``gateway_cls`` on ``port`` and put it first in ``self.gateways``.

        The new gateway gets a copy of this gateway's runtime args, with
        ``port``/``protocol`` overridden and non-copyable attributes skipped.
        It also shares this gateway's streamer.
        """
        # ignore metrics_registry since it is not copyable
        runtime_args = self._deepcopy_with_ignore_attrs(
            self.runtime_args,
            [
                'metrics_registry',
                'tracer_provider',
                'grpc_tracing_server_interceptors',
                'aio_tracing_client_interceptors',
                'tracing_client_interceptor',
                'monitoring',  # disable it for fastapi gateway
            ],
        )
        runtime_args.port = [port]
        runtime_args.protocol = [protocol]
        gateway_kwargs = {k: v for k, v in kwargs.items() if k != 'runtime_args'}
        gateway_kwargs['runtime_args'] = dict(vars(runtime_args))
        gateway = gateway_cls(**gateway_kwargs)
        gateway.streamer = self.streamer
        self.gateways.insert(0, gateway)
|