Implemented the '--ai-settings' flag

This commit is contained in:
Eesa Hamza
2023-04-13 14:02:42 +03:00
parent 428caa9bef
commit 0f6fba7d65
3 changed files with 30 additions and 1 deletions

View File

@@ -39,6 +39,7 @@ class Config(metaclass=Singleton):
self.speak_mode = False self.speak_mode = False
self.skip_reprompt = False self.skip_reprompt = False
self.ai_settings_file = os.getenv("AI_SETTINGS_FILE", "ai_settings.yaml")
self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-3.5-turbo") self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-3.5-turbo")
self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4") self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000)) self.fast_token_limit = int(os.getenv("FAST_TOKEN_LIMIT", 4000))

View File

@@ -182,7 +182,7 @@ def load_variables(config_file="config.yaml"):
def construct_prompt(): def construct_prompt():
"""Construct the prompt for the AI to respond to""" """Construct the prompt for the AI to respond to"""
config = AIConfig.load() config = AIConfig.load(cfg.ai_settings_file)
if cfg.skip_reprompt and config.ai_name: if cfg.skip_reprompt and config.ai_name:
logger.typewriter_log("Name :", Fore.GREEN, config.ai_name) logger.typewriter_log("Name :", Fore.GREEN, config.ai_name)
logger.typewriter_log("Role :", Fore.GREEN, config.ai_role) logger.typewriter_log("Role :", Fore.GREEN, config.ai_role)
@@ -325,6 +325,20 @@ def parse_arguments():
logger.typewriter_log("Skip Re-prompt: ", Fore.GREEN, "ENABLED") logger.typewriter_log("Skip Re-prompt: ", Fore.GREEN, "ENABLED")
cfg.skip_reprompt = True cfg.skip_reprompt = True
if args.ai_settings_file:
file = args.ai_settings_file
# Validate file
(validated, message) = utils.validate_yaml_file(file)
if not validated:
logger.typewriter_log("FAILED FILE VALIDATION", Fore.RED, message)
exit(1)
logger.typewriter_log("Using AI Settings File:", Fore.GREEN, file)
cfg.ai_settings_file = file
cfg.skip_reprompt = True
# TODO: fill in llm values here # TODO: fill in llm values here
check_openai_api_key() check_openai_api_key()

View File

@@ -1,3 +1,6 @@
import yaml
from colorama import Fore
def clean_input(prompt: str=''): def clean_input(prompt: str=''):
try: try:
return input(prompt) return input(prompt)
@@ -6,3 +9,14 @@ def clean_input(prompt: str=''):
print("Quitting...") print("Quitting...")
exit(0) exit(0)
def validate_yaml_file(file: str):
try:
with open(file) as file:
yaml.load(file, Loader=yaml.FullLoader)
except FileNotFoundError:
return (False, f"The file {Fore.CYAN}`{file}`{Fore.RESET} wasn't found")
except yaml.YAMLError as e:
return (False, f"There was an issue while trying to read with your AI Settings file: {e}")
return (True, f"Successfully validated {Fore.CYAN}`{file}`{Fore.RESET}!")