mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 22:14:28 +01:00
feat(ImageGen): add stable diffusion support
This commit is contained in:
@@ -9,6 +9,7 @@ import ai_functions as ai
|
|||||||
from file_operations import read_file, write_to_file, append_to_file, delete_file, search_files
|
from file_operations import read_file, write_to_file, append_to_file, delete_file, search_files
|
||||||
from execute_code import execute_python_file
|
from execute_code import execute_python_file
|
||||||
from json_parser import fix_and_parse_json
|
from json_parser import fix_and_parse_json
|
||||||
|
from image_gen import generate_image
|
||||||
from duckduckgo_search import ddg
|
from duckduckgo_search import ddg
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
from googleapiclient.errors import HttpError
|
from googleapiclient.errors import HttpError
|
||||||
@@ -102,6 +103,8 @@ def execute_command(command_name, arguments):
|
|||||||
return ai.write_tests(arguments["code"], arguments.get("focus"))
|
return ai.write_tests(arguments["code"], arguments.get("focus"))
|
||||||
elif command_name == "execute_python_file": # Add this command
|
elif command_name == "execute_python_file": # Add this command
|
||||||
return execute_python_file(arguments["file"])
|
return execute_python_file(arguments["file"])
|
||||||
|
elif command_name == "generate_image": # Add this command
|
||||||
|
return generate_image(arguments["prompt"])
|
||||||
elif command_name == "task_complete":
|
elif command_name == "task_complete":
|
||||||
shutdown()
|
shutdown()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ class Config(metaclass=Singleton):
|
|||||||
self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
||||||
self.pinecone_region = os.getenv("PINECONE_ENV")
|
self.pinecone_region = os.getenv("PINECONE_ENV")
|
||||||
|
|
||||||
|
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
|
||||||
|
|
||||||
# User agent headers to use when browsing web
|
# User agent headers to use when browsing web
|
||||||
# Some websites might just completely deny request with an error code if no user agent was found.
|
# Some websites might just completely deny request with an error code if no user agent was found.
|
||||||
self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"}
|
self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ COMMANDS:
|
|||||||
17. Write Tests: "write_tests", args: "code": "<full_code_string>", "focus": "<list_of_focus_areas>"
|
17. Write Tests: "write_tests", args: "code": "<full_code_string>", "focus": "<list_of_focus_areas>"
|
||||||
18. Execute Python File: "execute_python_file", args: "file": "<file>"
|
18. Execute Python File: "execute_python_file", args: "file": "<file>"
|
||||||
19. Task Complete (Shutdown): "task_complete", args: "reason": "<reason>"
|
19. Task Complete (Shutdown): "task_complete", args: "reason": "<reason>"
|
||||||
|
20. Generate Image: "generate_image", args: "prompt": "<prompt>"
|
||||||
|
|
||||||
RESOURCES:
|
RESOURCES:
|
||||||
|
|
||||||
|
|||||||
@@ -1,44 +1,28 @@
|
|||||||
from kandinsky2 import get_kandinsky2
|
import requests
|
||||||
|
import io
|
||||||
|
import os.path
|
||||||
|
from PIL import Image
|
||||||
from config import Config
|
from config import Config
|
||||||
|
import uuid
|
||||||
|
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
|
||||||
|
working_directory = "auto_gpt_workspace"
|
||||||
|
|
||||||
|
API_URL = "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4"
|
||||||
|
headers = {"Authorization": "Bearer " + cfg.huggingface_api_token}
|
||||||
|
|
||||||
def generate_image(prompt):
|
def generate_image(prompt):
|
||||||
|
response = requests.post(API_URL, headers=headers, json={
|
||||||
model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1', use_flash_attention=False)
|
"inputs": prompt,
|
||||||
images = model.generate_text2img(
|
})
|
||||||
"red cat, 4k photo", # prompt
|
image = Image.open(io.BytesIO(response.content))
|
||||||
num_steps=100,
|
print("Image Generated for prompt:" + prompt)
|
||||||
batch_size=1,
|
|
||||||
guidance_scale=4,
|
|
||||||
h=768, w=768,
|
|
||||||
sampler='p_sampler',
|
|
||||||
prior_cf_scale=4,
|
|
||||||
prior_steps="5"
|
|
||||||
)
|
|
||||||
return images
|
|
||||||
|
|
||||||
# base_url = 'http://export.arxiv.org/api/query?'
|
|
||||||
# query = f'search_query=all:{search_query}&start=0&max_results={max_results}'
|
|
||||||
# url = base_url + query
|
|
||||||
# response = requests.get(url)
|
|
||||||
|
|
||||||
# if response.status_code == 200:
|
filename = str(uuid.uuid4()) + ".jpg"
|
||||||
# soup = BeautifulSoup(response.content, 'xml')
|
|
||||||
# entries = soup.find_all('entry')
|
|
||||||
|
|
||||||
# articles = []
|
image.save(os.path.join(working_directory, filename))
|
||||||
# for entry in entries:
|
|
||||||
# title = entry.title.text.strip()
|
|
||||||
# url = entry.id.text.strip()
|
|
||||||
# published = entry.published.text.strip()
|
|
||||||
|
|
||||||
# articles.append({
|
print("Saved to disk:" + filename)
|
||||||
# 'title': title,
|
|
||||||
# 'url': url,
|
|
||||||
# 'published': published
|
|
||||||
# })
|
|
||||||
|
|
||||||
# return articles
|
return str("Image " + filename + " saved to disk for prompt: " + prompt)
|
||||||
# else:
|
|
||||||
# return None
|
|
||||||
|
|||||||
Reference in New Issue
Block a user