Black and isort

This commit is contained in:
Tim Dockhorn
2024-02-29 12:35:51 -08:00
parent 1e30a2df80
commit c51e4e30c2
4 changed files with 90 additions and 55 deletions

View File

@@ -1,41 +1,42 @@
# Adding this at the very top of app.py to make 'generative-models' directory discoverable # Adding this at the very top of app.py to make 'generative-models' directory discoverable
import sys
import os import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'generative-models')) import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
import math import math
import random
import uuid
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import cv2 import cv2
import gradio as gr
import numpy as np import numpy as np
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
from fire import Fire from fire import Fire
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from torchvision.transforms import ToTensor from torchvision.transforms import ToTensor
from scripts.sampling.simple_video_sample import (
get_batch, get_unique_embedder_keys_from_conditioner, load_model)
from scripts.util.detection.nsfw_and_watermark_dectection import \ from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config from sgm.util import default, instantiate_from_config
from scripts.sampling.simple_video_sample import load_model, get_unique_embedder_keys_from_conditioner, get_batch
import gradio as gr
import uuid
import random
from huggingface_hub import hf_hub_download
# To download all svd models # To download all svd models
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints") # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints") # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints") # hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
# Define the repo, local directory and filename # Define the repo, local directory and filename
repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
local_dir = "checkpoints" local_dir = "checkpoints"
local_file_path = os.path.join(local_dir, filename) local_file_path = os.path.join(local_dir, filename)
@@ -43,17 +44,13 @@ local_file_path = os.path.join(local_dir, filename)
# Check if the file already exists # Check if the file already exists
if not os.path.exists(local_file_path): if not os.path.exists(local_file_path):
# If the file doesn't exist, download it # If the file doesn't exist, download it
hf_hub_download( hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
repo_id=repo_id,
filename=filename,
local_dir=local_dir
)
print("File downloaded.") print("File downloaded.")
else: else:
print("File already exists. No need to download.") print("File already exists. No need to download.")
version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models
device = "cuda" device = "cuda"
max_64_bit_int = 2**63 - 1 max_64_bit_int = 2**63 - 1
@@ -71,6 +68,7 @@ model, filter = load_model(
num_steps, num_steps,
) )
def sample( def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
seed: Optional[int] = None, seed: Optional[int] = None,
@@ -82,14 +80,14 @@ def sample(
decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda", device: str = "cuda",
output_folder: str = "outputs", output_folder: str = "outputs",
progress=gr.Progress(track_tqdm=True) progress=gr.Progress(track_tqdm=True),
): ):
""" """
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
""" """
fps_id = int(fps_id ) #casting float slider values to int) fps_id = int(fps_id) # casting float slider values to int)
if(randomize_seed): if randomize_seed:
seed = random.randint(0, max_64_bit_int) seed = random.randint(0, max_64_bit_int)
torch.manual_seed(seed) torch.manual_seed(seed)
@@ -260,23 +258,50 @@ def resize_image(image_path, output_size=(1024, 576)):
return cropped_image return cropped_image
with gr.Blocks() as demo:
gr.Markdown('''# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate `4s` vid from a single image at (`25 frames` at `6 fps`). Generation takes ~60s in an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
''')
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="filepath")
generate_btn = gr.Button("Generate")
video = gr.Video()
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=6, minimum=5, maximum=30)
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) with gr.Blocks() as demo:
generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video") gr.Markdown(
"""# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate `4s` vid from a single image at (`25 frames` at `6 fps`). Generation takes ~60s in an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
"""
)
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="filepath")
generate_btn = gr.Button("Generate")
video = gr.Video()
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(
label="Seed",
value=42,
randomize=True,
minimum=0,
maximum=max_64_bit_int,
step=1,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
motion_bucket_id = gr.Slider(
label="Motion bucket id",
info="Controls how much motion to add/remove from the image",
value=127,
minimum=1,
maximum=255,
)
fps_id = gr.Slider(
label="Frames per second",
info="The length of your video in seconds will be 25/fps",
value=6,
minimum=5,
maximum=30,
)
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
generate_btn.click(
fn=sample,
inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
outputs=[video, seed],
api_name="video",
)
if __name__ == "__main__": if __name__ == "__main__":
demo.queue(max_size=20) demo.queue(max_size=20)

View File

@@ -1,5 +1,6 @@
from streamlit_helpers import *
from st_keyup import st_keyup from st_keyup import st_keyup
from streamlit_helpers import *
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
VERSION2SPECS = { VERSION2SPECS = {
@@ -203,8 +204,12 @@ if __name__ == "__main__":
), ),
) )
sampler.n_sample_steps = n_steps sampler.n_sample_steps = n_steps
default_prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe." default_prompt = (
prompt = st_keyup("Enter a value", value=default_prompt, debounce=300, key="interactive_text") "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
)
prompt = st_keyup(
"Enter a value", value=default_prompt, debounce=300, key="interactive_text"
)
cols = st.columns([1, 5, 1]) cols = st.columns([1, 5, 1])
if mode != "skip": if mode != "skip":
@@ -217,7 +222,13 @@ if __name__ == "__main__":
sampler.noise_sampler = SeededNoise(seed=st.session_state.seed) sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
out = sample( out = sample(
model, sampler, H=512, W=512, seed=st.session_state.seed, prompt=prompt, filter=state.get("filter") model,
sampler,
H=512,
W=512,
seed=st.session_state.seed,
prompt=prompt,
filter=state.get("filter"),
) )
with cols[1]: with cols[1]:
st.image(out[0]) st.image(out[0])

View File

@@ -3,14 +3,12 @@ from typing import Callable, Iterable, Union
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import ( from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
XFORMERS_IS_AVAILABLE, AttnBlock, Decoder,
AttnBlock, MemoryEfficientAttnBlock,
Decoder, ResnetBlock)
MemoryEfficientAttnBlock, from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
ResnetBlock, timestep_embedding)
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.video_attention import VideoTransformerBlock from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass from sgm.util import partialclass

View File

@@ -1,7 +1,8 @@
import torch import torch
from ..modules.attention import * from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding from ..modules.diffusionmodules.util import (AlphaBlender, linear,
timestep_embedding)
class TimeMixSequential(nn.Sequential): class TimeMixSequential(nn.Sequential):