mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 06:14:21 +01:00
SV3D inference code (#300)
* Makes init changes for SV3D
* Small fixes: cond_aug
* Fixes SV3D checkpoint, fixes rembg
* Black formatting
* Adds streamlit demo, fixes simple sample script
* Removes SV3D video_decoder, keeps SV3D image_decoder
* Updates README
* Minor updates
* Remove GSO script

---------

Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
README.md (+18 lines)
@@ -3,6 +3,24 @@
 ## News

+**March 18, 2024**
+- We are releasing [SV3D](https://huggingface.co/stabilityai/sv3d), an image-to-video model for novel multi-view synthesis, for research purposes:
+  - SV3D was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
+  - SV3D_u: This variant generates orbital videos from single image inputs, without camera conditioning.
+  - SV3D_p: Extending the capability of SV3D_u, this variant accommodates both single images and orbital views, allowing for the creation of 3D video along specified camera paths.
+  - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+  - Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details.
+
+To run SV3D on a single image:
+`python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p`
+
+To run SVD or SV3D on a streamlit server:
+`streamlit run scripts/demo/video_sampling.py`
+
 **November 30, 2023**
 - Following the launch of SDXL-Turbo, we are releasing [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo).
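For convenience, the same entry point can also be driven from Python; a minimal sketch (it assumes the repository root is on `PYTHONPATH` and the SV3D checkpoint has been downloaded to `checkpoints/`, as described above):

```python
# Minimal sketch of programmatic SV3D sampling via the script's entry point.
from scripts.sampling.simple_video_sample import sample

sample(
    input_path="assets/test_image.png",  # ideally a white-background image with one object
    version="sv3d_p",                    # or "sv3d_u" for orbits without camera conditioning
    elevations_deg=10.0,                 # constant camera elevation across the 21 frames
)
```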
BIN assets/sv3d.gif (new file, 1.2 MiB; binary file not shown)
configs/inference/sv3d_p.yaml (new file, 118 lines)
@@ -0,0 +1,118 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 1280
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - input_key: cond_frames_without_noise
          is_trainable: False
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: polars_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuths_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0
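A hedged sketch of how this config can be instantiated, mirroring what the repo's scripts do with `instantiate_from_config`; note that this inference config carries no `ckpt_path`, so weights must be loaded separately:

```python
# Sketch: build the SV3D_p DiffusionEngine from the inference config above.
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/inference/sv3d_p.yaml")
model = instantiate_from_config(config.model)  # no weights loaded; config has no ckpt_path
```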
configs/inference/sv3d_u.yaml (new file, 106 lines)
@@ -0,0 +1,106 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 256
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - input_key: cond_frames_without_noise
          is_trainable: False
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0
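This config matches `sv3d_p.yaml` except that the `polars_rad`/`azimuths_rad` embedders are absent and `adm_in_channels` shrinks accordingly, since the vector conditioning is the concatenation of the `ConcatTimestepEmbedderND` outputs:

```python
# Why adm_in_channels differs between the two inference configs:
sv3d_u_adm = 256              # cond_aug embedding only
sv3d_p_adm = 256 + 512 + 512  # cond_aug + polars_rad + azimuths_rad embeddings
assert sv3d_p_adm == 1280
```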
model_licenses/LICENSE-SV3D (new file, 41 lines)
@@ -0,0 +1,41 @@
STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
Dated: March 18, 2024

"Agreement" means this Stable Non-Commercial Research Community License Agreement.

“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.

“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.

“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.

"Stability AI" or "we" means Stability AI Ltd and its affiliates.

"Software" means Stability AI’s proprietary software made available under this Agreement.

“Software Products” means the Models, Software and Documentation, individually or in any combination.

1. License Rights and Redistribution.
a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.
b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.
c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
4. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.
b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.
5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.

6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law principles.
@@ -19,6 +19,7 @@ pillow>=9.5.0
 pudb>=2022.1.3
 pytorch-lightning==2.0.1
 pyyaml>=6.0.1
+rembg
 scipy>=1.10.1
 streamlit>=0.73.1
 tensorboardx==2.6
@@ -23,9 +23,11 @@ from PIL import Image
 from torchvision.transforms import ToTensor

 from scripts.sampling.simple_video_sample import (
-    get_batch, get_unique_embedder_keys_from_conditioner, load_model)
-from scripts.util.detection.nsfw_and_watermark_dectection import \
-    DeepFloydDataFiltering
+    get_batch,
+    get_unique_embedder_keys_from_conditioner,
+    load_model,
+)
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.inference.helpers import embed_watermark
 from sgm.util import default, instantiate_from_config
@@ -5,6 +5,7 @@ from glob import glob
 from typing import Dict, List, Optional, Tuple, Union

 import cv2
+import imageio
 import numpy as np
 import streamlit as st
 import torch
@@ -15,25 +16,30 @@ from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig, OmegaConf
 from PIL import Image
 from safetensors.torch import load_file as load_safetensors
+from scripts.demo.discretization import (
+    Img2ImgDiscretizationWrapper,
+    Txt2NoisyDiscretizationWrapper,
+)
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import embed_watermark
+from sgm.modules.diffusionmodules.guiders import (
+    LinearPredictionGuider,
+    TrianglePredictionGuider,
+    VanillaCFG,
+)
+from sgm.modules.diffusionmodules.sampling import (
+    DPMPP2MSampler,
+    DPMPP2SAncestralSampler,
+    EulerAncestralSampler,
+    EulerEDMSampler,
+    HeunEDMSampler,
+    LinearMultistepSampler,
+)
+from sgm.util import append_dims, default, instantiate_from_config
 from torch import autocast
 from torchvision import transforms
 from torchvision.utils import make_grid, save_image
-
-from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
-                                         Txt2NoisyDiscretizationWrapper)
-from scripts.util.detection.nsfw_and_watermark_dectection import \
-    DeepFloydDataFiltering
-from sgm.inference.helpers import embed_watermark
-from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
-                                                  VanillaCFG)
-from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
-                                                   DPMPP2SAncestralSampler,
-                                                   EulerAncestralSampler,
-                                                   EulerEDMSampler,
-                                                   HeunEDMSampler,
-                                                   LinearMultistepSampler)
-from sgm.util import append_dims, default, instantiate_from_config

 @st.cache_resource()
 def init_st(version_dict, load_ckpt=True, load_filter=True):
@@ -222,6 +228,7 @@ def get_guider(options, key):
             "VanillaCFG",
             "IdentityGuider",
             "LinearPredictionGuider",
+            "TrianglePredictionGuider",
         ],
         options.get("guider", 0),
     )
@@ -252,7 +259,7 @@ def get_guider(options, key):
             value=options.get("cfg", 1.5),
             min_value=1.0,
         )
-        min_scale = st.number_input(
+        min_scale = st.sidebar.number_input(
             f"min guidance scale",
             value=options.get("min_cfg", 1.0),
             min_value=1.0,
@@ -268,6 +275,29 @@ def get_guider(options, key):
                 **additional_guider_kwargs,
             },
         }
+    elif guider == "TrianglePredictionGuider":
+        max_scale = st.number_input(
+            f"max-cfg-scale #{key}",
+            value=options.get("cfg", 2.5),
+            min_value=1.0,
+            max_value=10.0,
+        )
+        min_scale = st.sidebar.number_input(
+            f"min guidance scale",
+            value=options.get("min_cfg", 1.0),
+            min_value=1.0,
+            max_value=10.0,
+        )
+
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
+            "params": {
+                "max_scale": max_scale,
+                "min_scale": min_scale,
+                "num_frames": options["num_frames"],
+                **additional_guider_kwargs,
+            },
+        }
     else:
         raise NotImplementedError
     return guider_config
@@ -288,8 +318,8 @@ def init_sampling(
         f"num cols #{key}", value=num_cols, min_value=1, max_value=10
     )

-    steps = st.sidebar.number_input(
-        f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
+    steps = st.number_input(
+        f"steps #{key}", value=options.get("num_steps", 50), min_value=1, max_value=1000
     )
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
@@ -337,13 +367,13 @@ def get_discretization(discretization, options, key=1):
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
         }
     elif discretization == "EDMDiscretization":
-        sigma_min = st.number_input(
+        sigma_min = st.sidebar.number_input(
             f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
         )  # 0.0292
-        sigma_max = st.number_input(
+        sigma_max = st.sidebar.number_input(
             f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
         )  # 14.6146
-        rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
+        rho = st.sidebar.number_input(f"rho #{key}", value=options.get("rho", 3.0))
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
             "params": {
@@ -542,7 +572,12 @@ def do_sample(
                     assert T is not None

                     if isinstance(
-                        sampler.guider, (VanillaCFG, LinearPredictionGuider)
+                        sampler.guider,
+                        (
+                            VanillaCFG,
+                            LinearPredictionGuider,
+                            TrianglePredictionGuider,
+                        ),
                     ):
                         additional_model_inputs[k] = torch.zeros(
                             num_samples[0] * 2, num_samples[1]
@@ -678,6 +713,12 @@ def get_batch(
             batch[key] = repeat(
                 value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
             )
+        elif key == "polars_rad":
+            batch[key] = torch.tensor(value_dict["polars_rad"]).to(device).repeat(N[0])
+        elif key == "azimuths_rad":
+            batch[key] = (
+                torch.tensor(value_dict["azimuths_rad"]).to(device).repeat(N[0])
+            )
         else:
             batch[key] = value_dict[key]
@@ -827,8 +868,13 @@ def load_img_for_prediction(
     st.image(image)
     w, h = image.size

-    image = np.array(image).transpose(2, 0, 1)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
+    image = np.array(image).astype(np.float32) / 255
+    if image.shape[-1] == 4:
+        rgb, alpha = image[:, :, :3], image[:, :, 3:]
+        image = rgb * alpha + (1 - alpha)
+
+    image = image.transpose(2, 0, 1)
+    image = torch.from_numpy(image).to(dtype=torch.float32)
     image = image.unsqueeze(0)

     rfs = get_resizing_factor((H, W), (h, w))
@@ -860,28 +906,16 @@ def save_video_as_grid_and_mp4(
         save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

         video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
-
-        writer = cv2.VideoWriter(
-            video_path,
-            cv2.VideoWriter_fourcc(*"MP4V"),
-            fps,
-            (vid.shape[-1], vid.shape[-2]),
-        )
-
         vid = (
             (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
         )
-        for frame in vid:
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-            writer.write(frame)
-
-        writer.release()
-
+        imageio.mimwrite(video_path, vid, fps=fps)
         video_path_h264 = video_path[:-4] + "_h264.mp4"
-        os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
+        os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'")
         with open(video_path_h264, "rb") as f:
             video_bytes = f.read()
+        os.remove(video_path_h264)
         st.video(video_bytes)

         base_count += 1
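The new RGBA branch in `load_img_for_prediction` above composites transparent pixels onto a white background before converting to a tensor; a self-contained sketch of the same math (`example.png` is a hypothetical input):

```python
# Sketch of the RGBA-to-white compositing added in load_img_for_prediction.
import numpy as np
from PIL import Image

image = np.array(Image.open("example.png")).astype(np.float32) / 255  # hypothetical file
if image.shape[-1] == 4:
    rgb, alpha = image[:, :, :3], image[:, :, 3:]
    image = rgb * alpha + (1 - alpha)  # fully transparent pixels become pure white
```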
scripts/demo/sv3d_helpers.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import os

import matplotlib.pyplot as plt
import numpy as np


def generate_dynamic_cycle_xy_values(
    length=21,
    init_elev=0,
    num_components=84,
    frequency_range=(1, 5),
    amplitude_range=(0.5, 10),
    step_range=(0, 2),
):
    # Y values generation
    y_sequence = np.ones(length) * init_elev
    for _ in range(num_components):
        # Choose a frequency that will complete whole cycles in the sequence
        frequency = np.random.randint(*frequency_range) * (2 * np.pi / length)
        amplitude = np.random.uniform(*amplitude_range)
        phase_shift = np.random.choice([0, np.pi])  # np.random.uniform(0, 2 * np.pi)
        angles = (
            np.linspace(0, frequency * length, length, endpoint=False) + phase_shift
        )
        y_sequence += np.sin(angles) * amplitude
    # X values generation
    # Generate length - 1 steps since the last step is back to start
    steps = np.random.uniform(*step_range, length - 1)
    total_step_sum = np.sum(steps)
    # Calculate the scale factor to scale total steps to just under 360
    scale_factor = (
        360 - ((360 / length) * np.random.uniform(*step_range))
    ) / total_step_sum
    # Apply the scale factor and generate the sequence of X values
    x_values = np.cumsum(steps * scale_factor)
    # Ensure the sequence starts at 0 and add the final step to complete the loop
    x_values = np.insert(x_values, 0, 0)
    return x_values, y_sequence


def smooth_data(data, window_size):
    # Extend data at both ends by wrapping around to create a continuous loop
    pad_size = window_size
    padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size]))

    # Apply smoothing
    kernel = np.ones(window_size) / window_size
    smoothed_data = np.convolve(padded_data, kernel, mode="same")

    # Extract the smoothed data corresponding to the original sequence
    # Adjust the indices to account for the larger padding
    start_index = pad_size
    end_index = -pad_size if pad_size != 0 else None
    smoothed_original_data = smoothed_data[start_index:end_index]
    return smoothed_original_data


# Function to generate and process the data
def gen_dynamic_loop(length=21, elev_deg=0):
    while True:
        # Generate the combined X and Y values using the new function
        azim_values, elev_values = generate_dynamic_cycle_xy_values(
            length=84, init_elev=elev_deg
        )
        # Smooth the Y values directly
        smoothed_elev_values = smooth_data(elev_values, 5)
        max_magnitude = np.max(np.abs(smoothed_elev_values))
        if max_magnitude < 90:
            break
    subsample = 84 // length
    azim_rad = np.deg2rad(azim_values[::subsample])
    elev_rad = np.deg2rad(smoothed_elev_values[::subsample])
    # Make cond frame the last one
    return np.roll(azim_rad, -1), np.roll(elev_rad, -1)


def plot_3D(azim, polar, save_path, dynamic=True):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    elev = np.deg2rad(90) - polar
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(projection="3d")
    cm = plt.get_cmap("Greys")
    col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]
    cm = plt.get_cmap("cool")
    col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]
    xs = np.cos(elev) * np.cos(azim)
    ys = np.cos(elev) * np.sin(azim)
    zs = np.sin(elev)
    ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])
    xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])
    for i in range(len(xs) - 1):
        if dynamic:
            ax.quiver(
                xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i]
            )
        else:
            ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i])
        ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
    ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
    ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
    ax.view_init(elev=30, azim=-20, roll=0)
    plt.savefig(save_path, bbox_inches="tight")
    plt.clf()
    plt.close()
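A short usage sketch for these helpers, matching how `video_sampling.py` calls them (the output path is illustrative):

```python
# Sketch: sample a random dynamic camera loop and visualize it.
import numpy as np
from scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D

azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=10)
polars_rad = np.deg2rad(90) - elev_rad  # polar angle is measured from the +z axis
plot_3D(azim=azim_rad, polar=polars_rad, save_path="outputs/demo/plot_3D.png", dynamic=True)
```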
@@ -1,8 +1,10 @@
 import os
+import sys
+
+sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
 from pytorch_lightning import seed_everything

 from scripts.demo.streamlit_helpers import *
+from scripts.demo.sv3d_helpers import *

 SAVE_PATH = "outputs/demo/vid/"
@@ -87,11 +89,51 @@ VERSION2SPECS = {
             "decoding_t": 14,
         },
     },
+    "sv3d_u": {
+        "T": 21,
+        "H": 576,
+        "W": 576,
+        "C": 4,
+        "f": 8,
+        "config": "configs/inference/sv3d_u.yaml",
+        "ckpt": "checkpoints/sv3d_u.safetensors",
+        "options": {
+            "discretization": 1,
+            "cfg": 2.5,
+            "sigma_min": 0.002,
+            "sigma_max": 700.0,
+            "rho": 7.0,
+            "guider": 3,
+            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
+            "num_steps": 50,
+            "decoding_t": 14,
+        },
+    },
+    "sv3d_p": {
+        "T": 21,
+        "H": 576,
+        "W": 576,
+        "C": 4,
+        "f": 8,
+        "config": "configs/inference/sv3d_p.yaml",
+        "ckpt": "checkpoints/sv3d_p.safetensors",
+        "options": {
+            "discretization": 1,
+            "cfg": 2.5,
+            "sigma_min": 0.002,
+            "sigma_max": 700.0,
+            "rho": 7.0,
+            "guider": 3,
+            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
+            "num_steps": 50,
+            "decoding_t": 14,
+        },
+    },
 }


 if __name__ == "__main__":
-    st.title("Stable Video Diffusion")
+    st.title("Stable Video Diffusion / SV3D")
     version = st.selectbox(
         "Model Version",
         [k for k in VERSION2SPECS.keys()],
@@ -131,17 +173,42 @@ if __name__ == "__main__":
         {},
     )

+    if "fps" not in ukeys:
+        value_dict["fps"] = 10
+
     value_dict["image_only_indicator"] = 0

     if mode == "img2vid":
         img = load_img_for_prediction(W, H)
-        cond_aug = st.number_input(
-            "Conditioning augmentation:", value=0.02, min_value=0.0
-        )
+        if "sv3d" in version:
+            cond_aug = 1e-5
+        else:
+            cond_aug = st.number_input(
+                "Conditioning augmentation:", value=0.02, min_value=0.0
+            )
         value_dict["cond_frames_without_noise"] = img
         value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
         value_dict["cond_aug"] = cond_aug

+    if "sv3d_p" in version:
+        elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
+        trajectory = st.selectbox(
+            "Trajectory",
+            ["same elevation", "dynamic"],
+            0,
+        )
+        if trajectory == "same elevation":
+            value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
+            value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
+        elif trajectory == "dynamic":
+            azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
+            value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
+            value_dict["azimuths_rad"] = azim_rad
+    elif "sv3d_u" in version:
+        elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
+        value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
+        value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
+
     seed = st.sidebar.number_input(
         "seed", value=23, min_value=0, max_value=int(1e9)
     )
@@ -151,6 +218,19 @@ if __name__ == "__main__":
         os.path.join(SAVE_PATH, version), init_value=True
     )

+    if "sv3d" in version:
+        plot_save_path = os.path.join(save_path, "plot_3D.png")
+        plot_3D(
+            azim=value_dict["azimuths_rad"],
+            polar=value_dict["polars_rad"],
+            save_path=plot_save_path,
+            dynamic=("sv3d_p" in version),
+        )
+        st.image(
+            plot_save_path,
+            f"3D camera trajectory",
+        )
+
     options["num_frames"] = T

     sampler, num_rows, num_cols = init_sampling(options=options)
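The "same elevation" branch above reduces to a uniform orbit at a fixed polar angle; a standalone sketch of the T = 21 conditioning values:

```python
# Sketch: the "same elevation" camera path used above for sv3d_u / sv3d_p.
import numpy as np

T, elev_deg = 21, 5
polars_rad = np.array([np.deg2rad(90 - elev_deg)] * T)  # fixed polar angle for all frames
azimuths_rad = np.linspace(0, 2 * np.pi, T + 1)[1:]     # full orbit; cond frame lands last
```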
scripts/sampling/configs/sv3d_p.yaml (new file, 132 lines)
@@ -0,0 +1,132 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv3d_p_image_decoder.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 1280
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - input_key: cond_frames_without_noise
          is_trainable: False
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: polars_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuths_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
          params:
            max_scale: 2.5
scripts/sampling/configs/sv3d_u.yaml (new file, 120 lines)
@@ -0,0 +1,120 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv3d_u_image_decoder.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 256
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
          params:
            max_scale: 2.5
@@ -1,27 +1,29 @@
 import math
 import os
+import sys
 from glob import glob
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
+
+sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
 import cv2
+import imageio
 import numpy as np
 import torch
 from einops import rearrange, repeat
 from fire import Fire
 from omegaconf import OmegaConf
 from PIL import Image
-from torchvision.transforms import ToTensor
-from scripts.util.detection.nsfw_and_watermark_dectection import \
-    DeepFloydDataFiltering
+from rembg import remove
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.inference.helpers import embed_watermark
 from sgm.util import default, instantiate_from_config
+from torchvision.transforms import ToTensor


 def sample(
     input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
-    num_frames: Optional[int] = None,
+    num_frames: Optional[int] = None,  # 21 for SV3D
     num_steps: Optional[int] = None,
     version: str = "svd",
     fps_id: int = 6,
@@ -31,6 +33,10 @@ def sample(
     decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: str = "cuda",
     output_folder: Optional[str] = None,
+    elevations_deg: Optional[float | List[float]] = 10.0,  # For SV3D
+    azimuths_deg: Optional[float | List[float]] = None,  # For SV3D
+    image_frame_ratio: Optional[float] = None,
+    verbose: Optional[bool] = False,
 ):
     """
     Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
@@ -61,6 +67,24 @@ def sample(
             output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
         )
         model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
+    elif version == "sv3d_u":
+        num_frames = 21
+        num_steps = default(num_steps, 50)
+        output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/")
+        model_config = "scripts/sampling/configs/sv3d_u.yaml"
+        cond_aug = 1e-5
+    elif version == "sv3d_p":
+        num_frames = 21
+        num_steps = default(num_steps, 50)
+        output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/")
+        model_config = "scripts/sampling/configs/sv3d_p.yaml"
+        cond_aug = 1e-5
+        if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
+            elevations_deg = [elevations_deg] * num_frames
+        polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
+        if azimuths_deg is None:
+            azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
+        azimuths_rad = [np.deg2rad(a) for a in azimuths_deg]
     else:
         raise ValueError(f"Version {version} does not exist.")
@@ -69,6 +93,7 @@ def sample(
         device,
         num_frames,
         num_steps,
+        verbose,
     )
     torch.manual_seed(seed)
@@ -93,20 +118,56 @@ def sample(
             raise ValueError

     for input_img_path in all_img_paths:
-        with Image.open(input_img_path) as image:
-            if image.mode == "RGBA":
-                image = image.convert("RGB")
-            w, h = image.size
-
-            if h % 64 != 0 or w % 64 != 0:
-                width, height = map(lambda x: x - x % 64, (w, h))
-                image = image.resize((width, height))
-                print(
-                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
-                )
-
-            image = ToTensor()(image)
-            image = image * 2.0 - 1.0
+        if "sv3d" in version:
+            image = Image.open(input_img_path)
+            if image.mode == "RGBA":
+                pass
+            else:
+                # remove bg
+                image.thumbnail([768, 768], Image.Resampling.LANCZOS)
+                image = remove(image.convert("RGBA"), alpha_matting=True)
+
+            # resize object in frame
+            image_arr = np.array(image)
+            in_w, in_h = image_arr.shape[:2]
+            ret, mask = cv2.threshold(
+                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+            )
+            x, y, w, h = cv2.boundingRect(mask)
+            max_size = max(w, h)
+            side_len = (
+                int(max_size / image_frame_ratio)
+                if image_frame_ratio is not None
+                else in_w
+            )
+            padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+            center = side_len // 2
+            padded_image[
+                center - h // 2 : center - h // 2 + h,
+                center - w // 2 : center - w // 2 + w,
+            ] = image_arr[y : y + h, x : x + w]
+            # resize frame to 576x576
+            rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
+            # white bg
+            rgba_arr = np.array(rgba) / 255.0
+            rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
+            input_image = Image.fromarray((rgb * 255).astype(np.uint8))
+        else:
+            with Image.open(input_img_path) as image:
+                if image.mode == "RGBA":
+                    input_image = image.convert("RGB")
+                w, h = image.size
+
+                if h % 64 != 0 or w % 64 != 0:
+                    width, height = map(lambda x: x - x % 64, (w, h))
+                    input_image = input_image.resize((width, height))
+                    print(
+                        f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+                    )
+
+        image = ToTensor()(input_image)
+        image = image * 2.0 - 1.0

         image = image.unsqueeze(0).to(device)
         H, W = image.shape[2:]
@@ -114,10 +175,14 @@ def sample(
         F = 8
         C = 4
         shape = (num_frames, C, H // F, W // F)
-        if (H, W) != (576, 1024):
+        if (H, W) != (576, 1024) and "sv3d" not in version:
             print(
                 "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
             )
+        if (H, W) != (576, 576) and "sv3d" in version:
+            print(
+                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
+            )
         if motion_bucket_id > 255:
             print(
                 "WARNING: High motion bucket! This may lead to suboptimal performance."
@@ -130,12 +195,14 @@ def sample(
             print("WARNING: Large fps value! This may lead to suboptimal performance.")

         value_dict = {}
+        value_dict["cond_frames_without_noise"] = image
         value_dict["motion_bucket_id"] = motion_bucket_id
         value_dict["fps_id"] = fps_id
         value_dict["cond_aug"] = cond_aug
-        value_dict["cond_frames_without_noise"] = image
         value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
-        value_dict["cond_aug"] = cond_aug
+        if "sv3d_p" in version:
+            value_dict["polars_rad"] = polars_rad
+            value_dict["azimuths_rad"] = azimuths_rad

         with torch.no_grad():
             with torch.autocast(device):
@@ -177,16 +244,15 @@ def sample(
                 samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                 model.en_and_decode_n_samples_a_time = decoding_t
                 samples_x = model.decode_first_stage(samples_z)
+                if "sv3d" in version:
+                    samples_x[-1:] = value_dict["cond_frames_without_noise"]
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                 os.makedirs(output_folder, exist_ok=True)
                 base_count = len(glob(os.path.join(output_folder, "*.mp4")))
-                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
-                writer = cv2.VideoWriter(
-                    video_path,
-                    cv2.VideoWriter_fourcc(*"MP4V"),
-                    fps_id + 1,
-                    (samples.shape[-1], samples.shape[-2]),
-                )
+
+                imageio.imwrite(
+                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
+                )

                 samples = embed_watermark(samples)
@@ -197,10 +263,8 @@ def sample(
                     .numpy()
                     .astype(np.uint8)
                 )
-                for frame in vid:
-                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-                    writer.write(frame)
-                writer.release()
+                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+                imageio.mimwrite(video_path, vid)


 def get_unique_embedder_keys_from_conditioner(conditioner):
@@ -230,12 +294,10 @@ def get_batch(keys, value_dict, N, T, device):
                 "1 -> b",
                 b=math.prod(N),
             )
-        elif key == "cond_frames":
-            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
-        elif key == "cond_frames_without_noise":
-            batch[key] = repeat(
-                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
-            )
+        elif key == "cond_frames" or key == "cond_frames_without_noise":
+            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
+        elif key == "polars_rad" or key == "azimuths_rad":
+            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
         else:
             batch[key] = value_dict[key]
@@ -253,6 +315,7 @@ def load_model(
     device: str,
     num_frames: int,
     num_steps: int,
+    verbose: bool = False,
 ):
     config = OmegaConf.load(config)
     if device == "cuda":
@@ -260,6 +323,7 @@ def load_model(
             0
         ].params.open_clip_embedding_config.params.init_device = device

+    config.model.params.sampler_config.params.verbose = verbose
     config.model.params.sampler_config.params.num_steps = num_steps
     config.model.params.sampler_config.params.guider_config.params.num_frames = (
         num_frames
@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union

 import torch
 from einops import rearrange, repeat
@@ -97,3 +97,35 @@ class LinearPredictionGuider(Guider):
                 assert c[k] == uc[k]
                 c_out[k] = c[k]
         return torch.cat([x] * 2), torch.cat([s] * 2), c_out


+class TrianglePredictionGuider(LinearPredictionGuider):
+    def __init__(
+        self,
+        max_scale: float,
+        num_frames: int,
+        min_scale: float = 1.0,
+        period: float | List[float] = 1.0,
+        period_fusing: Literal["mean", "multiply", "max"] = "max",
+        additional_cond_keys: Optional[Union[List[str], str]] = None,
+    ):
+        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
+        values = torch.linspace(0, 1, num_frames)
+        # Constructs a triangle wave
+        if isinstance(period, float):
+            period = [period]
+
+        scales = []
+        for p in period:
+            scales.append(self.triangle_wave(values, p))
+
+        if period_fusing == "mean":
+            scale = sum(scales) / len(period)
+        elif period_fusing == "multiply":
+            scale = torch.prod(torch.stack(scales), dim=0)
+        elif period_fusing == "max":
+            scale = torch.max(torch.stack(scales), dim=0).values
+        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
+
+    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
+        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
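`triangle_wave` maps the per-frame positions in [0, 1] to a 0 → 1 → 0 ramp (for `period=1.0`), so guidance peaks on the view opposite the conditioning frame; a quick numeric check:

```python
# Sketch: guidance scales TrianglePredictionGuider would produce for 5 frames.
import torch

values = torch.linspace(0, 1, 5)
wave = 2 * (values / 1.0 - torch.floor(values / 1.0 + 0.5)).abs()
# wave -> tensor([0.0000, 0.5000, 1.0000, 0.5000, 0.0000])
scales = wave * (2.5 - 1.0) + 1.0  # min_scale=1.0, max_scale=2.5
# scales -> tensor([1.0000, 1.7500, 2.5000, 1.7500, 1.0000])
```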