refactor: cleanup

Florian Hönicke
2023-04-04 14:02:29 +02:00
parent 0fc213c807
commit de5f9cc2a8
6 changed files with 95 additions and 72 deletions


@@ -1,6 +1,6 @@
 import os
 from time import sleep
-from typing import Union, List, Tuple
+from typing import List, Tuple
 import openai
 from openai.error import RateLimitError, Timeout
@@ -9,13 +9,39 @@ from src.prompt_system import system_base_definition
 from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
 from src.utils.string_tools import print_colored
 
+PRICING_GPT4_PROMPT = 0.03
+PRICING_GPT4_GENERATION = 0.06
+PRICING_GPT3_5_TURBO_PROMPT = 0.002
+PRICING_GPT3_5_TURBO_GENERATION = 0.002
+
 if 'OPENAI_API_KEY' not in os.environ:
     raise Exception('You need to set OPENAI_API_KEY in your environment')
 openai.api_key = os.environ['OPENAI_API_KEY']
+
+try:
+    openai.ChatCompletion.create(
+        model="gpt-4",
+        messages=[{
+            "role": 'system',
+            "content": 'test'
+        }]
+    )
+    supported_model = 'gpt-4'
+    pricing_prompt = PRICING_GPT4_PROMPT
+    pricing_generation = PRICING_GPT4_GENERATION
+except openai.error.InvalidRequestError:
+    supported_model = 'gpt-3.5-turbo'
+    pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
+    pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
 
 total_chars_prompt = 0
 total_chars_generation = 0
 
 class Conversation:
-    def __init__(self, prompt_list: List[Tuple[str, str]] = None):
+    def __init__(self, prompt_list: List[Tuple[str, str]] = None, model=supported_model):
+        self.model = model
         if prompt_list is None:
             prompt_list = [('system', system_base_definition)]
         self.prompt_list = prompt_list
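
The probe above detects GPT-4 access by issuing a throwaway request and falling back to gpt-3.5-turbo. A minimal sketch of the same idea generalized to a preference-ordered candidate list (the helper name and the candidate tuple are illustrative, not part of this commit):

def detect_supported_model(candidates=('gpt-4', 'gpt-3.5-turbo')):
    # try each model with a one-token request; return the first one the
    # account can access (openai<1.0 raises InvalidRequestError otherwise)
    for model in candidates:
        try:
            openai.ChatCompletion.create(
                model=model,
                messages=[{"role": 'system', "content": 'test'}],
                max_tokens=1,
            )
            return model
        except openai.error.InvalidRequestError:
            continue
    raise Exception('none of the candidate models is available')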
@@ -24,49 +50,48 @@ class Conversation:
     def query(self, prompt: str):
         print_colored('user', prompt, 'blue')
         self.prompt_list.append(('user', prompt))
-        response = get_response(self.prompt_list)
+        response = self.get_response(self.prompt_list)
         self.prompt_list.append(('assistant', response))
         return response
 
-def get_response(prompt_list: List[Tuple[str, str]]):
-    global total_chars_prompt, total_chars_generation
-    for i in range(10):
-        try:
-            response_generator = openai.ChatCompletion.create(
-                temperature=0,
-                max_tokens=2_000,
-                model="gpt-4",
-                stream=True,
-                messages=[
-                    {
-                        "role": prompt[0],
-                        "content": prompt[1]
-                    }
-                    for prompt in prompt_list
-                ]
-            )
-            response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
-            total_chars_prompt += sum(len(prompt[1]) for prompt in prompt_list)
-            complete_string = ''
-            for chunk in response_generator_with_timeout:
-                delta = chunk['choices'][0]['delta']
-                if 'content' in delta:
-                    content = delta['content']
-                    print_colored('' if complete_string else 'assistant', content, 'green', end='')
-                    complete_string += content
-                    total_chars_generation += len(content)
-            print('\n')
-            money_prompt = round(total_chars_prompt / 3.4 * 0.03 / 1000, 2)
-            money_generation = round(total_chars_generation / 3.4 * 0.06 / 1000, 2)
-            print('money prompt:', f'${money_prompt}')
-            print('money generation:', f'${money_generation}')
-            print('total money:', f'${money_prompt + money_generation}')
-            print('\n')
-            return complete_string
-        except (RateLimitError, Timeout, ConnectionError, GenerationTimeoutError) as e:
-            print(e)
-            print('retrying')
-            sleep(3)
-            continue
-    raise Exception('Failed to get response')
+    def get_response(self, prompt_list: List[Tuple[str, str]]):
+        global total_chars_prompt, total_chars_generation
+        for i in range(10):
+            try:
+                response_generator = openai.ChatCompletion.create(
+                    temperature=0,
+                    max_tokens=2_000,
+                    model=self.model,
+                    stream=True,
+                    messages=[
+                        {
+                            "role": prompt[0],
+                            "content": prompt[1]
+                        }
+                        for prompt in prompt_list
+                    ]
+                )
+                response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
+                total_chars_prompt += sum(len(prompt[1]) for prompt in prompt_list)
+                complete_string = ''
+                for chunk in response_generator_with_timeout:
+                    delta = chunk['choices'][0]['delta']
+                    if 'content' in delta:
+                        content = delta['content']
+                        print_colored('' if complete_string else 'assistant', content, 'green', end='')
+                        complete_string += content
+                        total_chars_generation += len(content)
+                print('\n')
+                money_prompt = round(total_chars_prompt / 3.4 * pricing_prompt / 1000, 2)
+                money_generation = round(total_chars_generation / 3.4 * pricing_generation / 1000, 2)
+                print('money prompt:', f'${money_prompt}')
+                print('money generation:', f'${money_generation}')
+                print('total money:', f'${money_prompt + money_generation}')
+                print('\n')
+                return complete_string
+            except (RateLimitError, Timeout, ConnectionError, GenerationTimeoutError) as e:
+                print(e)
+                print('retrying')
+                sleep(3)
+                continue
+        raise Exception('Failed to get response')
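
For reference, the cost printout estimates token counts as characters / 3.4 and multiplies by the per-1K-token price: 34,000 prompt characters come out to roughly 10,000 tokens, or $0.30 at the GPT-4 prompt rate of $0.03. A minimal usage sketch of the refactored class (the module path is an assumption; the diff does not show the file name):

from src.conversation import Conversation  # module path assumed, not shown in the diff

conversation = Conversation(model='gpt-3.5-turbo')  # override the auto-detected default
answer = conversation.query('Summarize the executor requirements in one sentence.')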

src/server.py Normal file (54 additions)

@@ -0,0 +1,54 @@
+from fastapi import FastAPI
+from fastapi.exceptions import RequestValidationError
+from jina import Flow
+from pydantic import BaseModel, HttpUrl
+from typing import Optional, Dict
+from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
+from starlette.responses import JSONResponse
+
+from main import main
+
+app = FastAPI()
+
+
+# Define the request model
+class CreateRequest(BaseModel):
+    test_scenario: str
+    executor_description: str
+
+
+# Define the response model
+class CreateResponse(BaseModel):
+    result: Dict[str, str]
+    success: bool
+    message: Optional[str]
+
+
+@app.post("/create", response_model=CreateResponse)
+def create_endpoint(request: CreateRequest):
+    result = main(
+        executor_description=request.executor_description,
+        test_scenario=request.test_scenario,
+    )
+    return CreateResponse(result=result, success=True, message=None)
+
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Add a custom exception handler for RequestValidationError
+@app.exception_handler(RequestValidationError)
+def validation_exception_handler(request: Request, exc: RequestValidationError):
+    return JSONResponse(
+        status_code=422,
+        content={"detail": exc.errors()},
+    )
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("server:app", host="0.0.0.0", port=8000, log_level="info")
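
Once the server is running, the /create endpoint can be exercised with a short client script (the field values below are placeholders, not part of this commit):

import requests

response = requests.post(
    'http://localhost:8000/create',
    json={
        'test_scenario': 'index ten documents and query the top-3 matches',
        'executor_description': 'an executor that encodes text with CLIP',
    },
)
response.raise_for_status()
print(response.json())  # {'result': {...}, 'success': True, 'message': None}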