Files
dev-gpt/dev_gpt/options/generate/static_files/gateway/custom_gateway.py
Florian Hönicke 1b0c2d5461 🌐 feat: enable cors
2023-05-15 10:54:32 +02:00

155 lines
5.6 KiB
Python

import os
import shutil
import subprocess
from time import sleep
from typing import List, Tuple
import streamlit.web.bootstrap
from jina import Gateway
from jina.serve.runtimes.gateway.composite import CompositeGateway
from streamlit.file_util import get_streamlit_file_path
from streamlit.web.server import Server as StreamlitServer
# Absolute directory of this module. It must be absolute: it is used to build
# the path handed to `nginx -c`, and nginx resolves a relative -c path against
# its own install prefix, not the current working directory.
cur_dir = os.path.dirname(os.path.abspath(__file__))
def cmd(command, std_output=False, wait=True):
    """Run an external command.

    Args:
        command: the command as a list of arguments, or a string that is split
            on whitespace (no shell quoting is interpreted).
        std_output: if True, the child inherits this process's stdout/stderr
            instead of having them captured through pipes.
        wait: if True, block until the command finishes.

    Returns:
        When ``wait`` is True, the ``(stdout, stderr)`` tuple returned by
        ``Popen.communicate()``: bytes when captured, ``(None, None)`` when
        ``std_output`` is True. When ``wait`` is False, the running
        ``subprocess.Popen`` object, so the caller can read its pipes and
        reap it.
    """
    if isinstance(command, str):
        command = command.split()
    if std_output:
        process = subprocess.Popen(command)
    else:
        process = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    if not wait:
        # Hand the process back instead of dropping it. Without a handle the
        # caller could neither drain the PIPEs (a chatty child blocks once a
        # pipe buffer fills) nor wait for the child to exit.
        return process
    output, error = process.communicate()
    return output, error
class PlaygroundGateway(Gateway):
    """Gateway that serves the Streamlit script ``app.py`` located next to this file.

    On construction it copies this directory's ``app_config.toml`` over
    Streamlit's ``config.toml``. The Streamlit server is created in
    :meth:`setup_server`, started in :meth:`run_server` and stopped in
    :meth:`shutdown`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.streamlit_script = 'app.py'
        # Copy this directory's app_config.toml to the path returned by
        # get_streamlit_file_path("config.toml"), creating the parent
        # directory if needed. An existing config.toml there is overwritten.
        streamlit_config_toml_src = os.path.join(cur_dir, 'app_config.toml')
        streamlit_config_toml_dest = get_streamlit_file_path("config.toml")
        os.makedirs(os.path.dirname(streamlit_config_toml_dest), exist_ok=True)
        shutil.copyfile(streamlit_config_toml_src, streamlit_config_toml_dest)

    async def setup_server(self):
        """Apply Streamlit's bootstrap fixes and create the Streamlit server."""
        script_path = os.path.join(cur_dir, self.streamlit_script)
        # Pass the full script path, as streamlit.web.bootstrap.run() does.
        # _fix_sys_path inserts dirname(path) into sys.path, so the bare
        # 'app.py' would add '' (the CWD) instead of the app's directory.
        streamlit.web.bootstrap._fix_sys_path(script_path)
        streamlit.web.bootstrap._fix_matplotlib_crash()
        streamlit.web.bootstrap._fix_tornado_crash()
        streamlit.web.bootstrap._fix_sys_argv(script_path, ())
        streamlit.web.bootstrap._fix_pydeck_mapbox_api_warning()
        streamlit_cmd = f'streamlit run {self.streamlit_script}'
        self.streamlit_server = StreamlitServer(script_path, streamlit_cmd)

    async def run_server(self):
        """Start the Streamlit server, then call bootstrap's start and signal-handler hooks."""
        await self.streamlit_server.start()
        streamlit.web.bootstrap._on_server_start(self.streamlit_server)
        streamlit.web.bootstrap._set_up_signal_handler(self.streamlit_server)

    async def shutdown(self):
        """Stop the Streamlit server, if setup_server() has created one."""
        # Guard against shutdown() being called after a failed or skipped
        # setup_server(); otherwise an AttributeError would mask the real error.
        server = getattr(self, 'streamlit_server', None)
        if server is not None:
            server.stop()
class CustomGateway(CompositeGateway):
    """Composite gateway that puts nginx in front of the HTTP gateway and the playground.

    The gateway must be started with its HTTP port set to 8080, the port nginx
    listens on. On construction the HTTP gateway is moved to port 8082, a
    :class:`PlaygroundGateway` is added on port 8501, and nginx is started
    with the ``nginx.conf`` located next to this file. Nginx is configured to
    route the requests in the following way:

    - /playground -> playground on port 8501
    - / -> HTTP gateway on port 8082
    """

    def __init__(self, **kwargs):
        # need to update port to 8082, as nginx will listen on 8080
        http_idx = 0
        http_port = kwargs['runtime_args']['port'][http_idx]
        if http_port != 8080:
            raise ValueError(
                f'Please, let http port ({http_port}) be 8080 for nginx to work'
            )
        # NOTE: this rewrites the caller's runtime_args port list in place.
        kwargs['runtime_args']['port'][http_idx] = 8082
        kwargs['cors'] = True
        super().__init__(**kwargs)

        # remove potential clashing arguments from kwargs before they are
        # forwarded to the playground gateway
        kwargs.pop("port", None)
        kwargs.pop("protocol", None)
        # note order is important
        self._add_gateway(
            PlaygroundGateway,
            8501,
            **kwargs,
        )
        self.setup_nginx()
        self.nginx_was_shutdown = False

    async def shutdown(self):
        """Shut down all wrapped gateways, then stop nginx (at most once)."""
        try:
            await super().shutdown()
        finally:
            # Stop nginx even if a wrapped gateway failed to shut down;
            # otherwise the nginx process would be left running.
            if not self.nginx_was_shutdown:
                self.shutdown_nginx()
                self.nginx_was_shutdown = True

    def setup_nginx(self):
        """Start nginx with the ``nginx.conf`` that sits next to this file."""
        command = [
            'nginx',
            '-c',
            os.path.join(cur_dir, 'nginx.conf'),
        ]
        output, error = self._run_nginx_command(command)
        self.logger.info('Nginx started')
        self.logger.info(f'nginx output: {output}')
        self.logger.info(f'nginx error: {error}')

    def shutdown_nginx(self):
        """Send nginx the ``stop`` signal."""
        command = ['nginx', '-s', 'stop']
        output, error = self._run_nginx_command(command)
        self.logger.info('Nginx stopped')
        self.logger.info(f'nginx output: {output}')
        self.logger.info(f'nginx error: {error}')

    def _run_nginx_command(self, command: List[str]) -> Tuple[bytes, bytes]:
        """Run an nginx command, retrying once with ``sudo`` if it wrote to stderr.

        Any stderr output (not the exit status) is treated as a failure. The
        caller's ``command`` list is not modified. Always sleeps 10 seconds
        before returning.

        Returns:
            ``(stdout, stderr)`` of the last attempt.
        """
        self.logger.info(f'Running command: {command}')
        output, error = cmd(command)
        if error != b'':
            # on CI we need to use sudo; using NOW_CI_RUN isn't good if running test locally
            self.logger.info(f'nginx error: {error}')
            # Build a new list rather than mutating the caller's argument.
            command = ['sudo', *command]
            self.logger.info(f'So running command: {command}')
            output, error = cmd(command)
        # NOTE(review): presumably gives nginx time to start/stop. This is a
        # blocking sleep, and it is also reached from the async shutdown().
        sleep(10)
        return output, error

    def _add_gateway(self, gateway_cls, port, protocol='http', **kwargs):
        """Create ``gateway_cls`` on ``port``/``protocol`` and prepend it to ``self.gateways``.

        The new gateway receives a deep copy of this gateway's runtime args
        (minus the attributes listed below) and shares this gateway's streamer.
        ``kwargs`` are forwarded to the constructor; any ``runtime_args`` entry
        in them is replaced by the copied runtime args.
        """
        # ignore metrics_registry & tracing objects since they are not copyable
        runtime_args = self._deepcopy_with_ignore_attrs(
            self.runtime_args,
            [
                'metrics_registry',
                'tracer_provider',
                'grpc_tracing_server_interceptors',
                'aio_tracing_client_interceptors',
                'tracing_client_interceptor',
                'monitoring',  # disable it for fastapi gateway
            ],
        )
        runtime_args.port = [port]
        runtime_args.protocol = [protocol]
        gateway_kwargs = {k: v for k, v in kwargs.items() if k != 'runtime_args'}
        gateway_kwargs['runtime_args'] = dict(vars(runtime_args))
        gateway = gateway_cls(**gateway_kwargs)
        gateway.streamer = self.streamer
        self.gateways.insert(0, gateway)