mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-23 08:54:24 +01:00
Pass Configs to Commands and remove CFG = Config() in the commands/ folder (#4328)
* feat: pass config to call_ai_functions in coimmands * feat: config for read_audio_from_file * feat: file operations cfg NOTE: we replaced the CFG in the command enable with TRUE b/c not sure how to handle this yet * feat: git command conversion * feat: google search * feat: image generation * feat: extract cfg from browser commands * feat: remove cfg from execute code commands * fix: file operation related tests * fix: linting * fix: tests for read_audio * fix: test error * feat: update cassettes * fix: linting * fix: test typechecking * fix: google_search errors if unexpected kw arg is passed * fix: pass config param to google search test * fix: agent commands were broken + cassettes * fix: agent test * feat: cassettes * feat: enable/disable logic for commands * fix: some commands threw errors * feat: fix tests * Add new cassettes * Add new cassettes * ci: trigger ci * Update autogpt/commands/execute_code.py Co-authored-by: Reinier van der Leer <github@pwuts.nl> * fix prompt * fix prompt + rebase * add config remove useless imports * put back CFG just for download file * lint * The signature should be mandatory in the decorator * black isort * fix: remove the CFG * fix: non typed arg * lint: type some args * lint: add types for libraries * Add new cassettes * fix: windows compatibility * fix: add config access to decorator * fix: remove twitter mention * DDGS search works at 3.0.2 version * ci: linting --------- Co-authored-by: Auto-GPT-Bot <github-bot@agpt.co> Co-authored-by: merwanehamadi <merwanehamadi@gmail.com> Co-authored-by: Reinier van der Leer <github@pwuts.nl> Co-authored-by: kinance <kinance@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import json
|
||||
import time
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import openai
|
||||
import requests
|
||||
@@ -13,11 +14,18 @@ from autogpt.commands.command import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.logs import logger
|
||||
|
||||
CFG = Config()
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
|
||||
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
|
||||
def generate_image(prompt: str, size: int = 256) -> str:
|
||||
@command(
|
||||
"generate_image",
|
||||
"Generate Image",
|
||||
'"prompt": "<prompt>"',
|
||||
lambda config: config.image_provider,
|
||||
"Requires a image provider to be set.",
|
||||
)
|
||||
def generate_image(prompt: str, config: Config, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
@@ -27,21 +35,21 @@ def generate_image(prompt: str, size: int = 256) -> str:
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
filename = f"{config.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||
|
||||
# DALL-E
|
||||
if CFG.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename, size)
|
||||
if config.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename, size, config)
|
||||
# HuggingFace
|
||||
elif CFG.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename)
|
||||
elif config.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename, config)
|
||||
# SD WebUI
|
||||
elif CFG.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, size)
|
||||
elif config.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, config, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
|
||||
def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
def generate_image_with_hf(prompt: str, filename: str, config: Config) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
|
||||
Args:
|
||||
@@ -52,14 +60,14 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = (
|
||||
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
|
||||
f"https://api-inference.huggingface.co/models/{config.huggingface_image_model}"
|
||||
)
|
||||
if CFG.huggingface_api_token is None:
|
||||
if config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {CFG.huggingface_api_token}",
|
||||
"Authorization": f"Bearer {config.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
@@ -101,7 +109,9 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||
return f"Error creating image."
|
||||
|
||||
|
||||
def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
|
||||
def generate_image_with_dalle(
|
||||
prompt: str, filename: str, size: int, config: Config
|
||||
) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
@@ -126,7 +136,7 @@ def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
api_key=CFG.openai_api_key,
|
||||
api_key=config.openai_api_key,
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
@@ -142,6 +152,7 @@ def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
filename: str,
|
||||
config: Config,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
@@ -158,13 +169,13 @@ def generate_image_with_sd_webui(
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if CFG.sd_webui_auth:
|
||||
username, password = CFG.sd_webui_auth.split(":")
|
||||
if config.sd_webui_auth:
|
||||
username, password = config.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
|
||||
f"{config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
|
||||
Reference in New Issue
Block a user