ℹ refactor: better key info (#23)

This commit is contained in:
Florian Hönicke
2023-04-20 11:56:21 +02:00
committed by GitHub
parent 3de94043c4
commit 7a3b8a1eb8
3 changed files with 22 additions and 12 deletions

View File

@@ -19,11 +19,21 @@ from src.options.generate.templates_system import template_system_message_base,
from src.utils.string_tools import print_colored from src.utils.string_tools import print_colored
def configure_openai_api_key():
if 'OPENAI_API_KEY' not in os.environ:
print_colored('You need to set OPENAI_API_KEY in your environment.', '''
Run:
gptdeploy configure --key <your_openai_api_key>
If you have updated it already, please restart your terminal.
''', 'red')
exit(1)
openai.api_key = os.environ['OPENAI_API_KEY']
class GPTSession: class GPTSession:
def __init__(self, task_description, test_description, model: str = 'gpt-4', ): def __init__(self, task_description, test_description, model: str = 'gpt-4', ):
self.task_description = task_description self.task_description = task_description
self.test_description = test_description self.test_description = test_description
self.configure_openai_api_key()
if model == 'gpt-4' and self.is_gpt4_available(): if model == 'gpt-4' and self.is_gpt4_available():
self.pricing_prompt = PRICING_GPT4_PROMPT self.pricing_prompt = PRICING_GPT4_PROMPT
self.pricing_generation = PRICING_GPT4_GENERATION self.pricing_generation = PRICING_GPT4_GENERATION
@@ -42,15 +52,7 @@ class GPTSession:
self.model_name, self.cost_callback, self.task_description, self.test_description, system_definition_examples self.model_name, self.cost_callback, self.task_description, self.test_description, system_definition_examples
) )
@staticmethod
def configure_openai_api_key():
if 'OPENAI_API_KEY' not in os.environ:
raise Exception('''
You need to set OPENAI_API_KEY in your environment.
If you have updated it already, please restart your terminal.
'''
)
openai.api_key = os.environ['OPENAI_API_KEY']
@staticmethod @staticmethod
def is_gpt4_available(): def is_gpt4_available():

View File

@@ -15,6 +15,7 @@ from hubble.executor.helper import upload_file, archive_package, get_request_hea
from jcloud.flow import CloudFlow from jcloud.flow import CloudFlow
from jina import Flow from jina import Flow
from src.apis.gpt import configure_openai_api_key
from src.constants import DEMO_TOKEN from src.constants import DEMO_TOKEN
from src.utils.io import suppress_stdout, is_docker_running from src.utils.io import suppress_stdout, is_docker_running
from src.utils.string_tools import print_colored from src.utils.string_tools import print_colored

View File

@@ -3,9 +3,15 @@ import os
import click import click
from src.apis.gpt import configure_openai_api_key
from src.apis.jina_cloud import jina_auth_login from src.apis.jina_cloud import jina_auth_login
from src.options.configure.key_handling import set_api_key from src.options.configure.key_handling import set_api_key
def openai_api_key_needed(func):
def wrapper(*args, **kwargs):
configure_openai_api_key()
return func(*args, **kwargs)
return wrapper
def exception_interceptor(func): def exception_interceptor(func):
@functools.wraps(func) @functools.wraps(func)
@@ -41,6 +47,7 @@ def main(ctx):
click.echo(ctx.get_help()) click.echo(ctx.get_help())
@openai_api_key_needed
@main.command() @main.command()
@click.option('--description', required=True, help='Description of the microservice.') @click.option('--description', required=True, help='Description of the microservice.')
@click.option('--test', required=True, help='Test scenario for the microservice.') @click.option('--test', required=True, help='Test scenario for the microservice.')
@@ -66,7 +73,7 @@ def generate(
generator = Generator(description, test, model=model) generator = Generator(description, test, model=model)
generator.generate(path) generator.generate(path)
@openai_api_key_needed
@main.command() @main.command()
@path_param @path_param
def run(path): def run(path):
@@ -75,7 +82,7 @@ def run(path):
path = os.path.abspath(path) path = os.path.abspath(path)
Runner().run(path) Runner().run(path)
@openai_api_key_needed
@main.command() @main.command()
@path_param @path_param
def deploy(path): def deploy(path):