diff --git a/scripts/commands.py b/scripts/commands.py index fc10d1d0..bf8d7983 100644 --- a/scripts/commands.py +++ b/scripts/commands.py @@ -9,6 +9,7 @@ import ai_functions as ai from file_operations import read_file, write_to_file, append_to_file, delete_file, search_files from execute_code import execute_python_file from json_parser import fix_and_parse_json +from image_gen import generate_image from duckduckgo_search import ddg from googleapiclient.discovery import build from googleapiclient.errors import HttpError @@ -102,6 +103,8 @@ def execute_command(command_name, arguments): return ai.write_tests(arguments["code"], arguments.get("focus")) elif command_name == "execute_python_file": # Add this command return execute_python_file(arguments["file"]) + elif command_name == "generate_image": # Add this command + return generate_image(arguments["prompt"]) elif command_name == "task_complete": shutdown() else: diff --git a/scripts/config.py b/scripts/config.py index fe48d298..2eca1675 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -53,6 +53,8 @@ class Config(metaclass=Singleton): self.pinecone_api_key = os.getenv("PINECONE_API_KEY") self.pinecone_region = os.getenv("PINECONE_ENV") + self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN") + # User agent headers to use when browsing web # Some websites might just completely deny request with an error code if no user agent was found. self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"} diff --git a/scripts/data/prompt.txt b/scripts/data/prompt.txt index 28797d9e..363342c0 100644 --- a/scripts/data/prompt.txt +++ b/scripts/data/prompt.txt @@ -23,6 +23,7 @@ COMMANDS: 17. Write Tests: "write_tests", args: "code": "", "focus": "" 18. Execute Python File: "execute_python_file", args: "file": "" 19. Task Complete (Shutdown): "task_complete", args: "reason": "" +20. Generate Image: "generate_image", args: "prompt": "" RESOURCES: diff --git a/scripts/image_gen.py b/scripts/image_gen.py index cdc4fc4d..bb3e7686 100644 --- a/scripts/image_gen.py +++ b/scripts/image_gen.py @@ -1,44 +1,28 @@ -from kandinsky2 import get_kandinsky2 +import requests +import io +import os.path +from PIL import Image from config import Config +import uuid cfg = Config() +working_directory = "auto_gpt_workspace" + +API_URL = "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4" +headers = {"Authorization": "Bearer " + cfg.huggingface_api_token} + def generate_image(prompt): - - model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1', use_flash_attention=False) - images = model.generate_text2img( - "red cat, 4k photo", # prompt - num_steps=100, - batch_size=1, - guidance_scale=4, - h=768, w=768, - sampler='p_sampler', - prior_cf_scale=4, - prior_steps="5" - ) - return images - - # base_url = 'http://export.arxiv.org/api/query?' - # query = f'search_query=all:{search_query}&start=0&max_results={max_results}' - # url = base_url + query - # response = requests.get(url) + response = requests.post(API_URL, headers=headers, json={ + "inputs": prompt, + }) + image = Image.open(io.BytesIO(response.content)) + print("Image Generated for prompt:" + prompt) - # if response.status_code == 200: - # soup = BeautifulSoup(response.content, 'xml') - # entries = soup.find_all('entry') + filename = str(uuid.uuid4()) + ".jpg" - # articles = [] - # for entry in entries: - # title = entry.title.text.strip() - # url = entry.id.text.strip() - # published = entry.published.text.strip() + image.save(os.path.join(working_directory, filename)) - # articles.append({ - # 'title': title, - # 'url': url, - # 'published': published - # }) + print("Saved to disk:" + filename) - # return articles - # else: - # return None + return str("Image " + filename + " saved to disk for prompt: " + prompt)