Black and isort

This commit is contained in:
Tim Dockhorn
2024-02-29 12:35:51 -08:00
parent 1e30a2df80
commit c51e4e30c2
4 changed files with 90 additions and 55 deletions

View File

@@ -1,41 +1,42 @@
# Adding this at the very top of app.py to make 'generative-models' directory discoverable
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'generative-models'))
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
import math
import random
import uuid
from glob import glob
from pathlib import Path
from typing import Optional
import cv2
import gradio as gr
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor
from scripts.sampling.simple_video_sample import (
get_batch, get_unique_embedder_keys_from_conditioner, load_model)
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from scripts.sampling.simple_video_sample import load_model, get_unique_embedder_keys_from_conditioner, get_batch
import gradio as gr
import uuid
import random
from huggingface_hub import hf_hub_download
# To download all svd models
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
# Define the repo, local directory and filename
repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
local_dir = "checkpoints"
local_file_path = os.path.join(local_dir, filename)
@@ -43,11 +44,7 @@ local_file_path = os.path.join(local_dir, filename)
# Check if the file already exists
if not os.path.exists(local_file_path):
# If the file doesn't exist, download it
hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=local_dir
)
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
print("File downloaded.")
else:
print("File already exists. No need to download.")
@@ -71,6 +68,7 @@ model, filter = load_model(
num_steps,
)
def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
seed: Optional[int] = None,
@@ -82,14 +80,14 @@ def sample(
decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
output_folder: str = "outputs",
progress=gr.Progress(track_tqdm=True)
progress=gr.Progress(track_tqdm=True),
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
fps_id = int(fps_id ) #casting float slider values to int)
if(randomize_seed):
fps_id = int(fps_id) # casting float slider values to int)
if randomize_seed:
seed = random.randint(0, max_64_bit_int)
torch.manual_seed(seed)
@@ -260,23 +258,50 @@ def resize_image(image_path, output_size=(1024, 576)):
return cropped_image
with gr.Blocks() as demo:
gr.Markdown('''# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
gr.Markdown(
"""# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate `4s` vid from a single image at (`25 frames` at `6 fps`). Generation takes ~60s in an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
''')
"""
)
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="filepath")
generate_btn = gr.Button("Generate")
video = gr.Video()
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
seed = gr.Slider(
label="Seed",
value=42,
randomize=True,
minimum=0,
maximum=max_64_bit_int,
step=1,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=6, minimum=5, maximum=30)
motion_bucket_id = gr.Slider(
label="Motion bucket id",
info="Controls how much motion to add/remove from the image",
value=127,
minimum=1,
maximum=255,
)
fps_id = gr.Slider(
label="Frames per second",
info="The length of your video in seconds will be 25/fps",
value=6,
minimum=5,
maximum=30,
)
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
generate_btn.click(
fn=sample,
inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
outputs=[video, seed],
api_name="video",
)
if __name__ == "__main__":
demo.queue(max_size=20)

View File

@@ -1,5 +1,6 @@
from streamlit_helpers import *
from st_keyup import st_keyup
from streamlit_helpers import *
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
VERSION2SPECS = {
@@ -203,8 +204,12 @@ if __name__ == "__main__":
),
)
sampler.n_sample_steps = n_steps
default_prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
prompt = st_keyup("Enter a value", value=default_prompt, debounce=300, key="interactive_text")
default_prompt = (
"A cinematic shot of a baby racoon wearing an intricate italian priest robe."
)
prompt = st_keyup(
"Enter a value", value=default_prompt, debounce=300, key="interactive_text"
)
cols = st.columns([1, 5, 1])
if mode != "skip":
@@ -217,7 +222,13 @@ if __name__ == "__main__":
sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
out = sample(
model, sampler, H=512, W=512, seed=st.session_state.seed, prompt=prompt, filter=state.get("filter")
model,
sampler,
H=512,
W=512,
seed=st.session_state.seed,
prompt=prompt,
filter=state.get("filter"),
)
with cols[1]:
st.image(out[0])

View File

@@ -3,14 +3,12 @@ from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import (
XFORMERS_IS_AVAILABLE,
AttnBlock,
Decoder,
from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
AttnBlock, Decoder,
MemoryEfficientAttnBlock,
ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
ResnetBlock)
from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
timestep_embedding)
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass

View File

@@ -1,7 +1,8 @@
import torch
from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
from ..modules.diffusionmodules.util import (AlphaBlender, linear,
timestep_embedding)
class TimeMixSequential(nn.Sequential):