mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-19 22:54:21 +01:00
🪓 feat: sub task refinement
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from time import sleep
|
||||
@@ -17,7 +18,7 @@ from urllib3.exceptions import InvalidChunkLength
|
||||
from dev_gpt.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
|
||||
PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
|
||||
from dev_gpt.options.generate.templates_system import template_system_message_base
|
||||
from dev_gpt.utils.string_tools import print_colored
|
||||
from dev_gpt.utils.string_tools import print_colored, get_template_parameters
|
||||
|
||||
|
||||
def configure_openai_api_key():
|
||||
@@ -32,8 +33,17 @@ If you have updated it already, please restart your terminal.
|
||||
openai.api_key = os.environ['OPENAI_API_KEY']
|
||||
|
||||
class GPTSession:
|
||||
def __init__(self, task_description, model: str = 'gpt-4', ):
|
||||
self.task_description = task_description
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(GPTSession, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, model: str = 'gpt-4', ):
|
||||
if GPTSession._initialized:
|
||||
return
|
||||
if model == 'gpt-4' and self.is_gpt4_available():
|
||||
self.pricing_prompt = PRICING_GPT4_PROMPT
|
||||
self.pricing_generation = PRICING_GPT4_GENERATION
|
||||
@@ -46,6 +56,7 @@ class GPTSession:
|
||||
self.model_name = model
|
||||
self.chars_prompt_so_far = 0
|
||||
self.chars_generation_so_far = 0
|
||||
GPTSession._initialized = True
|
||||
|
||||
def get_conversation(self, messages: List[BaseMessage] = [], print_stream: bool = True, print_costs: bool = True):
|
||||
messages = deepcopy(messages)
|
||||
@@ -151,3 +162,22 @@ class _GPTConversation:
|
||||
test_description=test_description,
|
||||
)
|
||||
return SystemMessage(content=system_message)
|
||||
|
||||
|
||||
def ask_gpt(prompt_template, parser, **kwargs):
|
||||
template_parameters = get_template_parameters(prompt_template)
|
||||
if set(template_parameters) != set(kwargs.keys()):
|
||||
raise ValueError(f'Prompt template parameters {get_template_parameters(prompt_template)} do not match '
|
||||
f'provided parameters {kwargs.keys()}')
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, dict):
|
||||
kwargs[key] = json.dumps(value, indent=4)
|
||||
prompt = prompt_template.format(**kwargs)
|
||||
conversation = GPTSession().get_conversation(
|
||||
[],
|
||||
print_stream=os.environ['VERBOSE'].lower() == 'true',
|
||||
print_costs=False
|
||||
)
|
||||
agent_response_raw = conversation.chat(prompt, role='user')
|
||||
agent_response = parser(agent_response_raw)
|
||||
return agent_response
|
||||
|
||||
Reference in New Issue
Block a user