Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

119
README.md
View File

@@ -4,26 +4,48 @@
## News
**November 21, 2023**
- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
- [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
frames at resolution 576x1024 given a context frame of the same size.
We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
- [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
for 25 frame generation.
- We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
- Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
![tile](assets/tile.gif)
**July 26, 2023**
- We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
- We are releasing two new open models with a
permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
hashes):
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
over `SDXL-base-0.9`.
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
over `SDXL-refiner-0.9`.
![sample2](assets/001_with_eval.png)
**July 4, 2023**
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
**June 22, 2023**
- We are releasing two new diffusion models for research purposes:
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
the OpenCLIP model.
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.
**We plan to do a full release soon (July).**
@@ -32,21 +54,32 @@ Please log in to your Hugging Face Account with your organization email to reque
### General Philosophy
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
### Changelog from the old `ldm` codebase
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
now `DiffusionEngine`) has been cleaned up:
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
* The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable
change is probably now the option to train continuous time models):
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers);
see `sgm/modules/diffusionmodules/denoiser.py`.
* The following features are now independent: weighting of the diffusion loss
function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- Autoencoding models have also been cleaned up.
## Installation:
<a name="installation"></a>
#### 1. Clone the repo
@@ -60,21 +93,10 @@ cd generative-models
This is assuming you have navigated to the `generative-models` root after cloning it.
**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
**PyTorch 1.13**
```shell
# install required packages from pypi
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```
**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
**PyTorch 2.0**
```shell
# install required packages from pypi
python3 -m venv .pt2
@@ -82,7 +104,6 @@ source .pt2/bin/activate
pip3 install -r requirements/pt2.txt
```
#### 3. Install `sgm`
```shell
@@ -114,8 +135,10 @@ depending on your use case and PyTorch version, manually.
## Inference
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (
see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
The following models are currently supported:
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -136,19 +159,20 @@ The following models are currently supported:
**Weights for SDXL**:
**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
The weights of SDXL-1.0 are available (subject to
a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.
After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using
@@ -166,6 +190,7 @@ not the same as in previous Stable Diffusion 1.x/2.x versions.
To run the script you need to either have a working installation as above or
try an _experimental_ import using only a minimal amount of packages:
```bash
python -m venv .detect
source .detect/bin/activate
@@ -177,6 +202,7 @@ pip install --no-deps invisible-watermark
To run the script you need to have a working installation as above. The script
is then useable in the following ways (don't forget to activate your
virtual environment beforehand, e.g. `source .pt1/bin/activate`):
```bash
# test a single file
python scripts/demo/detect.py <your filename here>
@@ -203,11 +229,21 @@ run
python main.py --base configs/example_training/toy/mnist_cond.yaml
```
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 1:** Using the non-toy-dataset
configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
used dataset (which is expected to stored in tar-file in
the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search
for comments containing `USER:` in the respective config.
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for
autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
only `pytorch1.13` is supported.
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
for the provided text-to-image configs.
### Building New Diffusion Models
@@ -216,7 +252,8 @@ python main.py --base configs/example_training/toy/mnist_cond.yaml
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input.
We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
appropriately.
@@ -229,7 +266,8 @@ enough as we plan to experiment with transformer-based diffusion backbones.
#### Loss
The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
The loss is configured through `loss_config`. For standard diffusion model training, you will have to
set `sigma_sampler_config`.
#### Sampler config
@@ -239,8 +277,9 @@ guidance.
### Dataset Handling
For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
For large scale training we recommend using the data pipelines from
our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement
and automatically included when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values,
e.g.,

BIN
assets/test_image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 482 KiB

BIN
assets/tile.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 MiB

View File

@@ -29,25 +29,14 @@ model:
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4 ]
ch_mult: [1, 2, 4]
num_res_blocks: 4
attn_resolutions: [ ]
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: none
double_z: False
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4 ]
num_res_blocks: 4
attn_resolutions: [ ]
dropout: 0.0
params: ${model.params.encoder_config.params}
data:
target: sgm.data.dataset.StableDataModuleFromConfig
@@ -55,18 +44,18 @@ data:
train:
datapipeline:
urls:
- "DATA-PATH"
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- "pil"
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg'
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:

View File

@@ -0,0 +1,105 @@
model:
base_learning_rate: 4.5e-6
target: sgm.models.autoencoder.AutoencodingEngine
params:
input_key: jpg
monitor: val/loss/rec
disc_start_iter: 0
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla-xformers
double_z: true
z_channels: 8
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params: ${model.params.encoder_config.params}
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
loss_config:
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
params:
perceptual_weight: 0.25
disc_start: 20001
disc_weight: 0.5
learn_logvar: True
regularization_weights:
kl_loss: 1.0
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
params:
h_key: height
w_key: width
loader:
batch_size: 8
num_workers: 4
lightning:
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 50000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
devices: 0,
limit_val_batches: 50
benchmark: True
accumulate_grad_batches: 1
val_check_interval: 10000

View File

@@ -21,8 +21,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,7 +30,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 256
@@ -42,7 +39,6 @@ model:
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1024
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
@@ -51,32 +47,31 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
add_sequence_dim: True # will be used through crossattn then
add_sequence_dim: True
embed_dim: 1024
n_classes: 1000
# vector cond
- is_trainable: False
ucg_rate: 0.2
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -99,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -127,18 +124,18 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- "DATA_PATH"
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- "pil"
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:

View File

@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 3
out_channels: 3
model_channels: 32
@@ -46,6 +41,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

View File

@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -32,6 +27,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

View File

@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,13 +13,12 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: [ ]
attention_resolutions: []
num_res_blocks: 4
channel_mult: [ 1, 2, 2 ]
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
@@ -33,7 +28,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: "cls"
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +41,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

View File

@@ -7,8 +7,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
discretization_config:
@@ -17,13 +15,12 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: [ ]
attention_resolutions: []
num_res_blocks: 4
channel_mult: [ 1, 2, 2 ]
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
@@ -33,7 +30,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: "cls"
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +43,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:

View File

@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -25,7 +20,7 @@ model:
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: "sequential"
num_classes: sequential
adm_in_channels: 128
conditioner_config:
@@ -33,7 +28,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: "cls"
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +41,11 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_type: l1
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
@@ -62,11 +62,6 @@ model:
params:
scale: 3.0
loss_config:
target: sgm.modules.diffusionmodules.StandardDiffusionLoss
params:
type: l1
data:
target: sgm.data.mnist.MNISTLoader
params:

View File

@@ -7,10 +7,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -19,7 +15,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -48,6 +43,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

View File

@@ -10,19 +10,17 @@ model:
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ]
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [1.]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,18 +30,16 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 1, 2, 4 ]
attention_resolutions: [1, 2, 4]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
@@ -52,7 +48,6 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
@@ -60,23 +55,23 @@ model:
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -99,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -127,18 +124,18 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- "DATA_PATH"
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- "pil"
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:

View File

@@ -10,19 +10,17 @@ model:
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ]
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [1.]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,18 +30,16 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 1, 2, 4 ]
attention_resolutions: [1, 2, 4]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
@@ -52,30 +48,30 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
legacy_ucg_value: ""
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -88,9 +84,9 @@ model:
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: [ ]
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
@@ -98,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -126,19 +124,19 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- "DATA_PATH"
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- "pil"
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:

View File

@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -20,7 +18,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
@@ -28,17 +25,14 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
@@ -47,7 +41,7 @@ model:
layer: penultimate
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss

View File

@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
discretization_config:
@@ -20,7 +18,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
@@ -28,17 +25,14 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
@@ -47,7 +41,7 @@ model:
layer: penultimate
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss

View File

@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -29,25 +27,22 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
transformer_depth: [1, 2, 10]
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -58,27 +53,27 @@ model:
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss

View File

@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -29,18 +27,15 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
context_dim: [1280, 1280, 1280, 1280]
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -51,27 +46,27 @@ model:
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
outdim: 256
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss

131
configs/inference/svd.yaml Normal file
View File

@@ -0,0 +1,131 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]

View File

@@ -0,0 +1,114 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@@ -0,0 +1,31 @@
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
Dated: November 21, 2023
“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Models output. For clarity, Derivative Works do not include the output of any Model.
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
"Stability AI" or "we" means Stability AI Ltd.
"Software" means, collectively, Stability AIs proprietary StableCode made available under this Agreement.
“Software Products” means Software and Documentation.
By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
License Rights and Redistribution.
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AIs intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
3. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
Subject to Stability AIs ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.

View File

@@ -0,0 +1,59 @@
import torch
from sgm.modules.diffusionmodules.discretizer import Discretization
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization: Discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(
self, discretization: Discretization, strength: float = 0.0, original_steps=None
):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas

View File

@@ -253,7 +253,10 @@ if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
if st.checkbox("Load Model"):
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
else:
mode = "skip"
st.write("__________________________")
set_lowvram_mode(st.checkbox("Low vram mode", True))
@@ -269,10 +272,11 @@ if __name__ == "__main__":
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
if mode != "skip":
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
is_legacy = version_dict["is_legacy"]
@@ -333,6 +337,8 @@ if __name__ == "__main__":
filter=state.get("filter"),
stage2strength=stage2strength,
)
elif mode == "skip":
out = None
else:
raise ValueError(f"unknown mode {mode}")
if isinstance(out, (tuple, list)):

View File

@@ -1,10 +1,15 @@
import copy
import math
import os
from typing import List, Union
from glob import glob
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as TT
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
@@ -12,63 +17,22 @@ from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.utils import make_grid, save_image
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims, instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was choosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper)
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
VanillaCFG)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler)
from sgm.util import append_dims, default, instantiate_from_config
@st.cache_resource()
@@ -164,11 +128,12 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
for key in keys:
if key == "txt":
if prompt is None:
prompt = st.text_input(
"Prompt", "A professional photograph of an astronaut riding a pig"
)
prompt = "A professional photograph of an astronaut riding a pig"
if negative_prompt is None:
negative_prompt = st.text_input("Negative prompt", "")
negative_prompt = ""
prompt = st.text_input("Prompt", prompt)
negative_prompt = st.text_input("Negative prompt", negative_prompt)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
@@ -203,13 +168,35 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
value_dict["target_width"] = init_dict["target_width"]
value_dict["target_height"] = init_dict["target_height"]
if key in ["fps_id", "fps"]:
fps = st.number_input("fps", value=6, min_value=1)
value_dict["fps"] = fps
value_dict["fps_id"] = fps - 1
if key == "motion_bucket_id":
mb_id = st.number_input("motion bucket id", 0, 511, value=127)
value_dict["motion_bucket_id"] = mb_id
if key == "pool_image":
st.text("Image for pool conditioning")
image = load_img(
key="pool_image_input",
size=224,
center_crop=True,
)
if image is None:
st.info("Need an image here")
image = torch.zeros(1, 3, 224, 224)
value_dict["pool_image"] = image
return value_dict
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watemark(samples)
samples = embed_watermark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
@@ -228,95 +215,99 @@ def init_save_locally(_dir, init_value: bool = False):
return save_locally, save_path
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def get_guider(key):
def get_guider(options, key):
guider = st.sidebar.selectbox(
f"Discretization #{key}",
[
"VanillaCFG",
"IdentityGuider",
"LinearPredictionGuider",
],
options.get("guider", 0),
)
additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
scale_schedule = st.sidebar.selectbox(
f"Scale schedule #{key}",
["Identity", "Oscillating"],
)
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
[
"None",
],
)
if scale_schedule == "Identity":
scale = st.number_input(
f"cfg-scale #{key}",
value=options.get("cfg", 5.0),
min_value=0.0,
)
if thresholder == "None":
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
scale_schedule_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
"params": {"scale": scale},
}
elif scale_schedule == "Oscillating":
small_scale = st.number_input(
f"small cfg-scale #{key}",
value=4.0,
min_value=0.0,
)
large_scale = st.number_input(
f"large cfg-scale #{key}",
value=16.0,
min_value=0.0,
)
sigma_cutoff = st.number_input(
f"sigma cutoff #{key}",
value=1.0,
min_value=0.0,
)
scale_schedule_config = {
"target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
"params": {
"small_scale": small_scale,
"large_scale": large_scale,
"sigma_cutoff": sigma_cutoff,
},
}
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
"params": {
"scale_schedule_config": scale_schedule_config,
**additional_guider_kwargs,
},
}
elif guider == "LinearPredictionGuider":
max_scale = st.number_input(
f"max-cfg-scale #{key}",
value=options.get("cfg", 1.5),
min_value=1.0,
)
min_scale = st.number_input(
f"min guidance scale",
value=options.get("min_cfg", 1.0),
min_value=1.0,
max_value=10.0,
)
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
"params": {
"max_scale": max_scale,
"min_scale": min_scale,
"num_frames": options["num_frames"],
**additional_guider_kwargs,
},
}
else:
raise NotImplementedError
@@ -325,18 +316,21 @@ def get_guider(key):
def init_sampling(
key=1,
img2img_strength=1.0,
specify_num_samples=True,
stage2strength=None,
img2img_strength: Optional[float] = None,
specify_num_samples: bool = True,
stage2strength: Optional[float] = None,
options: Optional[Dict[str, int]] = None,
):
options = {} if options is None else options
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
f"num cols #{key}", value=num_cols, min_value=1, max_value=10
)
steps = st.sidebar.number_input(
f"steps #{key}", value=40, min_value=1, max_value=1000
f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
@@ -348,7 +342,7 @@ def init_sampling(
"DPMPP2MSampler",
"LinearMultistepSampler",
],
0,
options.get("sampler", 0),
)
discretization = st.sidebar.selectbox(
f"Discretization #{key}",
@@ -356,14 +350,15 @@ def init_sampling(
"LegacyDDPMDiscretization",
"EDMDiscretization",
],
options.get("discretization", 0),
)
discretization_config = get_discretization(discretization, key=key)
discretization_config = get_discretization(discretization, options=options, key=key)
guider_config = get_guider(key=key)
guider_config = get_guider(options=options, key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
if img2img_strength is not None:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
)
@@ -377,15 +372,19 @@ def init_sampling(
return sampler, num_rows, num_cols
def get_discretization(discretization, key=1):
def get_discretization(discretization, options, key=1):
if discretization == "LegacyDDPMDiscretization":
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
rho = st.number_input(f"rho #{key}", value=3.0)
sigma_min = st.number_input(
f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
) # 0.0292
sigma_max = st.number_input(
f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
) # 14.6146
rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
@@ -474,8 +473,8 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
return sampler
def get_interactive_image(key=None) -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
def get_interactive_image() -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
@@ -483,8 +482,12 @@ def get_interactive_image(key=None) -> Image.Image:
return image
def load_img(display=True, key=None):
image = get_interactive_image(key=key)
def load_img(
display: bool = True,
size: Union[None, int, Tuple[int, int]] = None,
center_crop: bool = False,
):
image = get_interactive_image()
if image is None:
return None
if display:
@@ -492,12 +495,15 @@ def load_img(display=True, key=None):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 2.0 - 1.0),
]
)
transform = []
if size is not None:
transform.append(transforms.Resize(size))
if center_crop:
transform.append(transforms.CenterCrop(size))
transform.append(transforms.ToTensor())
transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
transform = transforms.Compose(transform)
img = transform(image)[None, ...]
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
return img
@@ -518,15 +524,18 @@ def do_sample(
W,
C,
F,
force_uc_zero_embeddings: List = None,
force_uc_zero_embeddings: Optional[List] = None,
force_cond_zero_embeddings: Optional[List] = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
T=None,
additional_batch_uc_fields=None,
decoding_t=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
batch2model_input = default(batch2model_input, [])
additional_batch_uc_fields = default(additional_batch_uc_fields, [])
st.text("Sampling")
@@ -535,24 +544,25 @@ def do_sample(
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
if T is not None:
num_samples = [num_samples, T]
else:
num_samples = [num_samples]
load_model(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_model(model.conditioner)
@@ -561,10 +571,29 @@ def do_sample(
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
if k in ["crossattn", "concat"] and T is not None:
uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
c[k] = repeat(c[k], "b ... -> b t ...", t=T)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
if k == "image_only_indicator":
assert T is not None
if isinstance(
sampler.guider, (VanillaCFG, LinearPredictionGuider)
):
additional_model_inputs[k] = torch.zeros(
num_samples[0] * 2, num_samples[1]
).to("cuda")
else:
additional_model_inputs[k] = torch.zeros(num_samples).to(
"cuda"
)
else:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
@@ -581,6 +610,9 @@ def do_sample(
unload_model(model.denoiser)
load_model(model.first_stage_model)
model.en_and_decode_n_samples_a_time = (
decoding_t # Decode n frames at a time
)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_model(model.first_stage_model)
@@ -588,16 +620,32 @@ def do_sample(
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if T is None:
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
else:
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
for i, vid in enumerate(as_vids):
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
st.image(
grid.cpu().numpy(),
f"Sample #{i} as image",
)
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
def get_batch(
keys,
value_dict: dict,
N: Union[List, ListConfig],
device: str = "cuda",
T: int = None,
additional_batch_uc_fields: List[str] = [],
):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
@@ -605,21 +653,15 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch["txt"] = [value_dict["prompt"]] * math.prod(N)
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
.repeat(math.prod(N), 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
@@ -627,30 +669,67 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
.repeat(math.prod(N), 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
torch.tensor([value_dict["aesthetic_score"]])
.to(device)
.repeat(math.prod(N), 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
.repeat(math.prod(N), 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
.repeat(math.prod(N), 1)
)
elif key == "fps":
batch[key] = (
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
)
elif key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(math.prod(N))
)
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
)
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
elif key in additional_batch_uc_fields and key not in batch_uc:
batch_uc[key] = copy.copy(batch[key])
return batch, batch_uc
@@ -661,7 +740,8 @@ def do_img2img(
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
force_uc_zero_embeddings: Optional[List] = None,
force_cond_zero_embeddings: Optional[List] = None,
additional_kwargs={},
offset_noise_level: int = 0.0,
return_latents=False,
@@ -686,6 +766,7 @@ def do_img2img(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_model(model.conditioner)
for k in c:
@@ -736,9 +817,112 @@ def do_img2img(
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
r_bound = desired_shape[1] / desired_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if aspect_r >= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r < 1.0:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r > 1:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
return factor
def get_interactive_image(key=None) -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def load_img_for_prediction(
W: int, H: int, display=True, key=None, device="cuda"
) -> torch.Tensor:
image = get_interactive_image(key=key)
if image is None:
return None
if display:
st.image(image)
w, h = image.size
image = np.array(image).transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
image = image.unsqueeze(0)
rfs = get_resizing_factor((H, W), (h, w))
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
top = (resize_size[0] - H) // 2
left = (resize_size[1] - W) // 2
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
if display:
numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
st.image(pil_image)
return image.to(device) * 2.0 - 1.0
def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
os.makedirs(save_path, exist_ok=True)
base_count = len(glob(os.path.join(save_path, "*.mp4")))
video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
video_batch = embed_watermark(video_batch)
for vid in video_batch:
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"),
fps,
(vid.shape[-1], vid.shape[-2]),
)
vid = (
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
video_path_h264 = video_path[:-4] + "_h264.mp4"
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
with open(video_path_h264, "rb") as f:
video_bytes = f.read()
st.video(video_bytes)
base_count += 1

View File

@@ -0,0 +1,200 @@
import os
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
SAVE_PATH = "outputs/demo/vid/"
VERSION2SPECS = {
"svd": {
"T": 14,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd.yaml",
"ckpt": "checkpoints/svd.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 25,
},
},
"svd_image_decoder": {
"T": 14,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd_image_decoder.yaml",
"ckpt": "checkpoints/svd_image_decoder.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 25,
},
},
"svd_xt": {
"T": 25,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd.yaml",
"ckpt": "checkpoints/svd_xt.safetensors",
"options": {
"discretization": 1,
"cfg": 3.0,
"min_cfg": 1.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 30,
"decoding_t": 14,
},
},
"svd_xt_image_decoder": {
"T": 25,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd_image_decoder.yaml",
"ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
"options": {
"discretization": 1,
"cfg": 3.0,
"min_cfg": 1.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 30,
"decoding_t": 14,
},
},
}
if __name__ == "__main__":
st.title("Stable Video Diffusion")
version = st.selectbox(
"Model Version",
[k for k in VERSION2SPECS.keys()],
0,
)
version_dict = VERSION2SPECS[version]
if st.checkbox("Load Model"):
mode = "img2vid"
else:
mode = "skip"
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
)
T = st.sidebar.number_input(
"T", value=version_dict["T"], min_value=0, max_value=128
)
C = version_dict["C"]
F = version_dict["f"]
options = version_dict["options"]
if mode != "skip":
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
ukeys = set(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
)
value_dict = init_embedder_options(
ukeys,
{},
)
value_dict["image_only_indicator"] = 0
if mode == "img2vid":
img = load_img_for_prediction(W, H)
cond_aug = st.number_input(
"Conditioning augmentation:", value=0.02, min_value=0.0
)
value_dict["cond_frames_without_noise"] = img
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
value_dict["cond_aug"] = cond_aug
seed = st.sidebar.number_input(
"seed", value=23, min_value=0, max_value=int(1e9)
)
seed_everything(seed)
save_locally, save_path = init_save_locally(
os.path.join(SAVE_PATH, version), init_value=True
)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling(options=options)
num_samples = num_rows * num_cols
decoding_t = st.number_input(
"Decode t frames at a time (set small if you are low on VRAM)",
value=options.get("decoding_t", T),
min_value=1,
max_value=int(1e9),
)
if st.checkbox("Overwrite fps in mp4 generator", False):
saving_fps = st.number_input(
f"saving video at fps:", value=value_dict["fps"], min_value=1
)
else:
saving_fps = value_dict["fps"]
if st.button("Sample"):
out = do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
T=T,
batch2model_input=["num_video_frames", "image_only_indicator"],
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
force_cond_zero_embeddings=options.get(
"force_cond_zero_embeddings", None
),
return_latents=False,
decoding_t=decoding_t,
)
if isinstance(out, (tuple, list)):
samples, samples_z = out
else:
samples = out
samples_z = None
if save_locally:
save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)

View File

@@ -0,0 +1,146 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 2.5
min_scale: 1.0

View File

@@ -0,0 +1,129 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 2.5
min_scale: 1.0

View File

@@ -0,0 +1,146 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_xt.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 3.0
min_scale: 1.5

View File

@@ -0,0 +1,129 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 3.0
min_scale: 1.5

View File

@@ -0,0 +1,278 @@
import math
import os
from glob import glob
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
num_frames: Optional[int] = None,
num_steps: Optional[int] = None,
version: str = "svd",
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 0.02,
seed: int = 23,
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
output_folder: Optional[str] = None,
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
if version == "svd":
num_frames = default(num_frames, 14)
num_steps = default(num_steps, 25)
output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
model_config = "scripts/sampling/configs/svd.yaml"
elif version == "svd_xt":
num_frames = default(num_frames, 25)
num_steps = default(num_steps, 30)
output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
model_config = "scripts/sampling/configs/svd_xt.yaml"
elif version == "svd_image_decoder":
num_frames = default(num_frames, 14)
num_steps = default(num_steps, 25)
output_folder = default(
output_folder, "outputs/simple_video_sample/svd_image_decoder/"
)
model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
elif version == "svd_xt_image_decoder":
num_frames = default(num_frames, 25)
num_steps = default(num_steps, 30)
output_folder = default(
output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
)
model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
else:
raise ValueError(f"Version {version} does not exist.")
model, filter = load_model(
model_config,
device,
num_frames,
num_steps,
)
torch.manual_seed(seed)
path = Path(input_path)
all_img_paths = []
if path.is_file():
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
all_img_paths = [input_path]
else:
raise ValueError("Path is not valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
else:
raise ValueError
for input_img_path in all_img_paths:
with Image.open(input_img_path) as image:
if image.mode == "RGBA":
image = image.convert("RGB")
w, h = image.size
if h % 64 != 0 or w % 64 != 0:
width, height = map(lambda x: x - x % 64, (w, h))
image = image.resize((width, height))
print(
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)
image = ToTensor()(image)
image = image * 2.0 - 1.0
image = image.unsqueeze(0).to(device)
H, W = image.shape[2:]
assert image.shape[1] == 3
F = 8
C = 4
shape = (num_frames, C, H // F, W // F)
if (H, W) != (576, 1024):
print(
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
)
if motion_bucket_id > 255:
print(
"WARNING: High motion bucket! This may lead to suboptimal performance."
)
if fps_id < 5:
print("WARNING: Small fps value! This may lead to suboptimal performance.")
if fps_id > 30:
print("WARNING: Large fps value! This may lead to suboptimal performance.")
value_dict = {}
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug
with torch.no_grad():
with torch.autocast(device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[1, num_frames],
T=num_frames,
device=device,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=[
"cond_frames",
"cond_frames_without_noise",
],
)
for k in ["crossattn", "concat"]:
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
randn = torch.randn(shape, device=device)
additional_model_inputs = {}
additional_model_inputs["image_only_indicator"] = torch.zeros(
2, num_frames
).to(device)
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
os.makedirs(output_folder, exist_ok=True)
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"),
fps_id + 1,
(samples.shape[-1], samples.shape[-2]),
)
samples = embed_watermark(samples)
samples = filter(samples)
vid = (
(rearrange(samples, "t c h w -> t h w c") * 255)
.cpu()
.numpy()
.astype(np.uint8)
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N, T, device):
batch = {}
batch_uc = {}
for key in keys:
if key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to(device),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
)
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def load_model(
config: str,
device: str,
num_frames: int,
num_steps: int,
):
config = OmegaConf.load(config)
if device == "cuda":
config.model.params.conditioner_config.params.emb_models[
0
].params.open_clip_embedding_config.params.init_device = device
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
)
if device == "cuda":
with torch.device(device):
model = instantiate_from_config(config.model).to(device).eval()
else:
model = instantiate_from_config(config.model).to(device).eval()
filter = DeepFloydDataFiltering(verbose=False, device=device)
return model, filter
if __name__ == "__main__":
Fire(sample)

View File

@@ -1,10 +1,10 @@
import torch
import einops
from torch.backends.cuda import SDPBackend
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend
from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
def benchmark_attn():

View File

@@ -37,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
class DeepFloydDataFiltering(object):
def __init__(self, verbose: bool = False):
def __init__(
self, verbose: bool = False, device: torch.device = torch.device("cpu")
):
super().__init__()
self.verbose = verbose
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
self._device = None
self.clip_model, _ = clip.load("ViT-L/14", device=device)
self.clip_model.eval()
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
@@ -54,7 +57,9 @@ class DeepFloydDataFiltering(object):
@torch.inference_mode()
def __call__(self, images: torch.Tensor) -> torch.Tensor:
imgs = clip_process_images(images)
image_features = self.clip_model.encode_image(imgs.to("cpu"))
if self._device is None:
self._device = next(p for p in self.clip_model.parameters()).device
image_features = self.clip_model.encode_image(imgs.to(self._device))
image_features = image_features.detach().cpu().numpy().astype(np.float16)
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)

View File

@@ -1,23 +1,20 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import pathlib
from sgm.inference.helpers import (
do_sample,
do_img2img,
Img2ImgDiscretizationWrapper,
)
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import load_model_from_config
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional
from omegaconf import OmegaConf
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
do_sample)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler)
from sgm.util import load_model_from_config
class ModelArchitecture(str, Enum):
SD_2_1 = "stable-diffusion-v2-1"

View File

@@ -1,13 +1,13 @@
import os
from typing import Union, List, Optional
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from PIL import Image
from torch import autocast
from sgm.util import append_dims
@@ -20,17 +20,16 @@ class WatermarkEmbedder:
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
image: ([N,] B, RGB, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
@@ -39,6 +38,7 @@ class WatermarkEmbedder:
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
# watermarking libary expects input as cv2 BGR format
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(

View File

@@ -1,18 +1,22 @@
import logging
import math
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
import torch.nn as nn
from einops import rearrange
from packaging import version
from safetensors.torch import load_file as load_safetensors
from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
from ..util import (default, get_nested_attribute, get_obj_from_str,
instantiate_from_config)
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list, ListConfig] = (),
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def init_from_ckpt(
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
if isinstance(ckpt, str):
ckpt = {
"target": "sgm.modules.checkpoint.CheckpointEngine",
"params": {"ckpt_path": ckpt},
}
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logpy.info(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config")
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
# todo: add options to freeze encoder/decoder
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.automatic_optimization = False # pytorch lightning
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config
)
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"}
)
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consitent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consitent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
# image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead if bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.regularization.get_trainable_parameters())
+ list(self.loss.get_trainable_autoencoder_parameters())
)
params = []
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, "get_trainable_parameters"):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(self, x: Any, return_reg_log: bool = False) -> Any:
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: Any) -> torch.Tensor:
x = self.decoder(z)
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": optimizer_idx,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "train",
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {"train/loss/rec": aeloss.detach()}
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
"loss",
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
if optimizer_idx == 1:
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def validation_step(self, batch, batch_idx) -> Dict:
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(
batch, batch_idx, optimizer_idx=optimizer_idx
)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": 0,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "val" + postfix,
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
full_log_dict = log_dict_ae
if "optimizer_idx" in extra_info:
extra_info["optimizer_idx"] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f"val{postfix}/loss/rec",
log_dict_ae[f"val{postfix}/loss/rec"],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
log_dict_ae.update(log_dict_disc)
self.log_dict(log_dict_ae)
return log_dict_ae
def configure_optimizers(self) -> Any:
ae_params = self.get_autoencoder_params()
disc_params = self.get_discriminator_params()
def get_param_groups(
self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(f"Did not find parameters for pattern {pattern_}")
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(
f"Number of trainable discriminator parameters: {num_disc_params:,}"
)
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc)
return [opt_ae, opt_disc], []
return opts
@torch.no_grad()
def log_images(self, batch: Dict, **kwargs) -> Dict:
def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
_, xrec, _ = self(x)
additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
log["reconstructions"] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
)
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x)
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log["reconstructions_ema"] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = "reconstructions-" + "-".join(
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
)
log[log_str] = xrec_add
return log
class AutoencoderKL(AutoencodingEngine):
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", ())
ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__(
encoder_config={"target": "torch.nn.Identity"},
decoder_config={"target": "torch.nn.Identity"},
regularizer_config={"target": "torch.nn.Identity"},
loss_config=kwargs.pop("lossconfig"),
encoder_config={
"target": "sgm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
assert ddconfig["double_z"]
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def encode(self, x):
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
def decode(self, z, **decoder_kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z, **decoder_kwargs)
return dec
class AutoencoderKLInferenceWrapper(AutoencoderKL):
def encode(self, x):
return super().encode(x).sample()
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
)
},
**kwargs,
)
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
def __init__(
self,
embed_dim: int,
n_embed: int,
sane_index_shape: bool = False,
**kwargs,
):
if "lossconfig" in kwargs:
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
),
"params": {
"n_e": n_embed,
"e_dim": embed_dim,
"sane_index_shape": sane_index_shape,
},
},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder):
@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class AEIntegerWrapper(nn.Module):
def __init__(
self,
model: nn.Module,
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
regularization_key: str = "regularization",
encoder_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.model = model
assert hasattr(model, "encode") and hasattr(
model, "decode"
), "Need AE interface"
self.regularization = get_nested_attribute(model, regularization_key)
self.shape = shape
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
def encode(self, x) -> torch.Tensor:
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
_, log = self.model.encode(x, **self.encoder_kwargs)
assert isinstance(log, dict)
inds = log["min_encoding_indices"]
return rearrange(inds, "b ... -> b (...)")
def decode(
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
) -> torch.Tensor:
# expect inds shape (b, s) with s = h*w
shape = default(shape, self.shape) # Optional[(h, w)]
if shape is not None:
assert len(shape) == 2, f"Unhandeled shape {shape}"
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
h = rearrange(h, "b h w c -> b c h w")
return self.model.decode(h)
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
),
"params": {"sample": False},
},
**kwargs,
)

View File

@@ -1,5 +1,6 @@
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
from ..util import (default, disabled_train, get_obj_from_str,
instantiate_from_config, log_txt_as_img)
class DiffusionEngine(pl.LightningModule):
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
log_keys: Union[List, None] = None,
no_cond_log: bool = False,
compile_model: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
):
super().__init__()
self.log_keys = log_keys
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def init_from_ckpt(
self,
path: str,
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
out = self.first_stage_model.decode(z)
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x):
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
z = self.first_stage_model.encode(x)
for n in range(n_rounds):
out = self.first_stage_model.encode(
x[n * n_samples : (n + 1) * n_samples]
)
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z

View File

@@ -1,3 +1,4 @@
import logging
import math
from inspect import isfunction
from typing import Any, Optional
@@ -7,6 +8,9 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint
logpy = logging.getLogger(__name__)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
@@ -36,9 +40,10 @@ else:
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
logpy.warn(
f"No SDP backend available, likely because you are running in pytorch "
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
f"You might want to consider upgrading."
)
try:
@@ -48,9 +53,9 @@ try:
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
logpy.warn("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
def exists(val):
@@ -146,6 +151,62 @@ class LinearAttention(nn.Module):
return self.to_out(out)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "math")
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
attn_mode: str = "xformers",
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
assert attn_mode in self.ATTENTION_MODES
self.attn_mode = attn_mode
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, C = x.shape
qkv = self.qkv(x)
if self.attn_mode == "torch":
qkv = rearrange(
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
).float()
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
elif self.attn_mode == "xformers":
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
x = xformers.ops.memory_efficient_attention(q, k, v)
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
elif self.attn_mode == "math":
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
else:
raise NotImplemented
x = self.proj(x)
x = self.proj_drop(x)
return x
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
@@ -289,9 +350,10 @@ class MemoryEfficientCrossAttention(nn.Module):
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads with a dimension of {dim_head}."
logpy.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
f"context_dim is {context_dim} and using {heads} heads with a "
f"dimension of {dim_head}."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
@@ -352,9 +414,29 @@ class MemoryEfficientCrossAttention(nn.Module):
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
# NOTE: workaround for
# https://github.com/facebookresearch/xformers/issues/845
max_bs = 32768
N = q.shape[0]
n_batches = math.ceil(N / max_bs)
out = list()
for i_batch in range(n_batches):
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
out.append(
xformers.ops.memory_efficient_attention(
q[batch],
k[batch],
v[batch],
attn_bias=None,
op=self.attention_op,
)
)
out = torch.cat(out, 0)
else:
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
@@ -393,21 +475,24 @@ class BasicTransformerBlock(nn.Module):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
logpy.warn(
f"Attention mode '{attn_mode}' is not available. Falling "
f"back to native attention. This is not a problem in "
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
f"version {torch.__version__}."
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
logpy.warn(
"We do not support vanilla attention anymore, as it is too "
"expensive. Sorry."
)
if not XFORMERS_IS_AVAILABLE:
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
logpy.info("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
@@ -437,7 +522,7 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -456,9 +541,12 @@ class BasicTransformerBlock(nn.Module):
)
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
if self.checkpoint:
# inputs = {"x": x, "context": context}
return checkpoint(self._forward, x, context)
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
else:
return self._forward(**kwargs)
def _forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -518,9 +606,9 @@ class BasicTransformerSingleLayerBlock(nn.Module):
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
# inputs = {"x": x, "context": context}
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
return checkpoint(self._forward, x, context)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
@@ -554,18 +642,20 @@ class SpatialTransformer(nn.Module):
sdp_backend=None,
):
super().__init__()
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
logpy.debug(
f"constructing {self.__class__.__name__} of depth {depth} w/ "
f"{in_channels} channels and {n_heads} heads."
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
logpy.warn(
f"{self.__class__.__name__}: Found context dims "
f"{context_dim} of depth {len(context_dim)}, which does not "
f"match the specified 'depth' of {depth}. Setting context_dim "
f"to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
assert all(
@@ -631,3 +721,39 @@ class SpatialTransformer(nn.Module):
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class SimpleTransformer(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dim_head: int,
context_dim: Optional[int] = None,
dropout: float = 0.0,
checkpoint: bool = True,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
BasicTransformerBlock(
dim,
heads,
dim_head,
dropout=dropout,
context_dim=context_dim,
attn_mode="softmax-xformers",
checkpoint=checkpoint,
)
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, context)
return x

View File

@@ -1,246 +1,7 @@
from typing import Any, Union
__all__ = [
"GeneralLPIPSWithDiscriminator",
"LatentLPIPS",
]
import torch
import torch.nn as nn
from einops import rearrange
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
pixelloss_weight=1.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss calculation, "
f"the LPIPS loss will be applied to each frame independently. "
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.learn_logvar = learn_logvar
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
def get_trainable_parameters(self) -> Any:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Any:
if self.learn_logvar:
yield self.logvar
yield from ()
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
regularization_log,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
split="train",
weights=None,
):
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
log.update(
{
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
)
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
from .discriminator_loss import GeneralLPIPSWithDiscriminator
from .lpips import LatentLPIPS

View File

@@ -0,0 +1,306 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, Dict[str, float]] = None,
additional_log_keys: Optional[List[str]] = None,
discriminator_config: Optional[Dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss "
f"calculation, the LPIPS loss will be applied to each frame "
f"independently."
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(
torch.full((), logvar_init), requires_grad=learn_logvar
)
self.learn_logvar = learn_logvar
discriminator_config = default(
discriminator_config,
{
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
"params": {
"input_nc": disc_in_channels,
"n_layers": disc_num_layers,
"use_actnorm": False,
},
},
)
self.discriminator = instantiate_from_config(discriminator_config).apply(
weights_init
)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
self.forward_keys = [
"optimizer_idx",
"global_step",
"last_layer",
"split",
"regularization_log",
]
self.additional_log_keys = set(default(additional_log_keys, []))
self.additional_log_keys.update(set(self.regularization_weights.keys()))
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
if self.learn_logvar:
yield self.logvar
yield from ()
@torch.no_grad()
def log_images(
self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
# Non patch-discriminator
return dict()
logits_fake = self.discriminator(reconstructions.contiguous().detach())
# -> (b, 1, h, w)
# parameters for colormapping
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
cmap = colormaps["PiYG"] # diverging colormap
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
"""(b, 1, ...) -> (b, 3, ...)"""
logits = (logits + high) / (2 * high)
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
# -> (b, 1, ..., 3)
logits = torch.from_numpy(logits_np).to(logits.device)
return rearrange(logits, "b 1 ... c -> b c ...")
logits_real = torch.nn.functional.interpolate(
logits_real,
size=inputs.shape[-2:],
mode="nearest",
antialias=False,
)
logits_fake = torch.nn.functional.interpolate(
logits_fake,
size=reconstructions.shape[-2:],
mode="nearest",
antialias=False,
)
# alpha value of logits for overlay
alpha_real = torch.abs(logits_real) / high
alpha_fake = torch.abs(logits_fake) / high
# -> (b, 1, h, w) in range [0, 0.5]
# alpha value of lines don't really matter, since the values are the same
# for both images and logits anyway
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
# -> (1, h, w)
# blend logits and images together
# prepare logits for plotting
logits_real = to_colormap(logits_real)
logits_fake = to_colormap(logits_fake)
# resize logits
# -> (b, 3, h, w)
# make some grids
# add all logits to one plot
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
# I just love how torchvision calls the number of columns `nrow`
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
# -> (3, h, w)
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
grid_images_fake = torchvision.utils.make_grid(
0.5 * reconstructions + 0.5, nrow=4
)
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
# -> (3, h, w) in range [0, 1]
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
# Create labeled colorbar
dpi = 100
height = 128 / dpi
width = grid_logits.shape[2] / dpi
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
plt.colorbar(
img,
cax=ax,
orientation="horizontal",
fraction=0.9,
aspect=width / height,
pad=0.0,
)
img.set_visible(False)
fig.tight_layout()
fig.canvas.draw()
# manually convert figure to numpy
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
# Add colorbar to plot
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
blended_grid = torch.cat((grid_blend, cbar), dim=1)
return {
"vis_logits": 2 * annotated_grid[None, ...] - 1,
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
}
def calculate_adaptive_weight(
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
) -> torch.Tensor:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs: torch.Tensor,
reconstructions: torch.Tensor,
*, # added because I changed the order here
regularization_log: Dict[str, torch.Tensor],
optimizer_idx: int,
global_step: int,
last_layer: torch.Tensor,
split: str = "train",
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
# now the GAN part
if optimizer_idx == 0:
# generator update
if global_step >= self.discriminator_iter_start or not self.training:
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else:
d_weight = torch.tensor(1.0)
else:
d_weight = torch.tensor(0.0)
g_loss = torch.tensor(0.0, requires_grad=True)
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
if k in self.additional_log_keys:
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
log.update(
{
f"{split}/loss/total": loss.clone().detach().mean(),
f"{split}/loss/nll": nll_loss.detach().mean(),
f"{split}/loss/rec": rec_loss.detach().mean(),
f"{split}/loss/g": g_loss.detach().mean(),
f"{split}/scalars/logvar": self.logvar.detach(),
f"{split}/scalars/d_weight": d_weight.detach(),
}
)
return loss, log
elif optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
if global_step >= self.discriminator_iter_start or not self.training:
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
else:
d_loss = torch.tensor(0.0, requires_grad=True)
log = {
f"{split}/loss/disc": d_loss.clone().detach().mean(),
f"{split}/logits/real": logits_real.detach().mean(),
f"{split}/logits/fake": logits_fake.detach().mean(),
}
return d_loss, log
else:
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
def get_nll_loss(
self,
rec_loss: torch.Tensor,
weights: Optional[Union[float, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
return nll_loss, weighted_nll_loss

View File

@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log

View File

@@ -5,19 +5,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import DiagonalGaussianDistribution
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
from ....modules.distributions.distributions import \
DiagonalGaussianDistribution
from .base import AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
@@ -39,15 +29,3 @@ class DiagonalGaussianRegularizer(AbstractRegularizer):
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
def measure_perplexity(predicted_indices, num_centroids):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use

View File

@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn.functional as F
from torch import nn
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(
predicted_indices: torch.Tensor, num_centroids: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use

View File

@@ -0,0 +1,487 @@
import logging
from abc import abstractmethod
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from .base import AbstractRegularizer, measure_perplexity
logpy = logging.getLogger(__name__)
class AbstractQuantizer(AbstractRegularizer):
def __init__(self):
super().__init__()
# Define these in your init
# shape (N,)
self.used: Optional[torch.Tensor]
self.re_embed: int
self.unknown_index: Union[Literal["random"], int]
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
@abstractmethod
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
raise NotImplementedError()
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
yield from self.parameters()
class GumbelQuantizer(AbstractQuantizer):
"""
credit to @karpathy:
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens: int,
embedding_dim: int,
n_embed: int,
straight_through: bool = True,
kl_weight: float = 5e-4,
temp_init: float = 1.0,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
) -> None:
super().__init__()
self.loss_key = loss_key
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
) -> Tuple[torch.Tensor, Dict]:
# force hard = True when we are in eval mode, as we must quantize.
# actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
out_dict = {}
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
)
out_dict[self.loss_key] = diff
ind = soft_one_hot.argmax(dim=1)
out_dict["indices"] = ind
if self.remap is not None:
ind = self.remap_to_used(ind)
if return_logits:
out_dict["logits"] = logits
return z_q, out_dict
def get_codebook_entry(self, indices, shape):
# TODO: shape not yet optional
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = (
F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
)
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
return z_q
class VectorQuantizer(AbstractQuantizer):
"""
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term,
beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(
self,
n_e: int,
e_dim: int,
beta: float = 0.25,
remap: Optional[str] = None,
unknown_index: str = "random",
sane_index_shape: bool = False,
log_perplexity: bool = False,
embedding_weight_norm: bool = False,
loss_key: str = "loss/vq",
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.loss_key = loss_key
if not embedding_weight_norm:
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
else:
self.embedding = torch.nn.utils.weight_norm(
nn.Embedding(self.n_e, self.e_dim), dim=1
)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_e
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
self.sane_index_shape = sane_index_shape
self.log_perplexity = log_perplexity
def forward(
self,
z: torch.Tensor,
) -> Tuple[torch.Tensor, Dict]:
do_reshape = z.ndim == 4
if do_reshape:
# # reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
else:
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
z = z.contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2
* torch.einsum(
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
)
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
loss_dict = {}
if self.log_perplexity:
perplexity, cluster_usage = measure_perplexity(
min_encoding_indices.detach(), self.n_e
)
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
# compute loss for embedding
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
loss_dict[self.loss_key] = loss
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
if do_reshape:
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
if do_reshape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
else:
min_encoding_indices = rearrange(
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
)
loss_dict["min_encoding_indices"] = min_encoding_indices
return z_q, loss_dict
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
assert shape is not None, "Need to give shape for remap"
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
super().__init__()
self.decay = decay
self.eps = eps
weight = torch.randn(num_tokens, codebook_dim)
self.weight = nn.Parameter(weight, requires_grad=False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
self.update = True
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(
new_cluster_size, alpha=1 - self.decay
)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
)
# normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
self.weight.data.copy_(embed_normalized)
class EMAVectorQuantizer(AbstractQuantizer):
def __init__(
self,
n_embed: int,
embedding_dim: int,
beta: float,
decay: float = 0.99,
eps: float = 1e-5,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
):
super().__init__()
self.codebook_dim = embedding_dim
self.num_tokens = n_embed
self.beta = beta
self.loss_key = loss_key
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
z = rearrange(z, "b c h w -> b h w c")
z_flattened = z.reshape(-1, self.codebook_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
z_flattened.pow(2).sum(dim=1, keepdim=True)
+ self.embedding.weight.pow(2).sum(dim=1)
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
if self.training and self.embedding.update:
# EMA cluster size
encodings_sum = encodings.sum(0)
self.embedding.cluster_size_ema_update(encodings_sum)
# EMA embedding average
embed_sum = encodings.transpose(0, 1) @ z_flattened
self.embedding.embed_avg_ema_update(embed_sum)
# normalize embed_avg and update weight
self.embedding.weight_update(self.num_tokens)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
# z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, "b h w c -> b c h w")
out_dict = {
self.loss_key: loss,
"encodings": encodings,
"encoding_indices": encoding_indices,
"perplexity": perplexity,
}
return z_q, out_dict
class VectorQuantizerWithInputProjection(VectorQuantizer):
def __init__(
self,
input_dim: int,
n_codes: int,
codebook_dim: int,
beta: float = 1.0,
output_dim: Optional[int] = None,
**kwargs,
):
super().__init__(n_codes, codebook_dim, beta, **kwargs)
self.proj_in = nn.Linear(input_dim, codebook_dim)
self.output_dim = output_dim
if output_dim is not None:
self.proj_out = nn.Linear(codebook_dim, output_dim)
else:
self.proj_out = nn.Identity()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
rearr = False
in_shape = z.shape
if z.ndim > 3:
rearr = self.output_dim is not None
z = rearrange(z, "b c ... -> b (...) c")
z = self.proj_in(z)
z_q, loss_dict = super().forward(z)
z_q = self.proj_out(z_q)
if rearr:
if len(in_shape) == 4:
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
elif len(in_shape) == 5:
z_q = rearrange(
z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
)
else:
raise NotImplementedError(
f"rearranging not available for {len(in_shape)}-dimensional input."
)
return z_q, loss_dict

View File

@@ -0,0 +1,349 @@
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import (
XFORMERS_IS_AVAILABLE,
AttnBlock,
Decoder,
MemoryEfficientAttnBlock,
ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
if timesteps is None:
timesteps = self.timesteps
b, c, h, w = x.shape
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps, skip_video=False):
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class VideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_video=False):
if skip_video:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax-xformers",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_time_block=False):
if skip_time_block:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
assert attn_type in [
"vanilla",
"vanilla-xformers",
], f"attn_type {attn_type} not supported for spatio-temporal attention"
print(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_type = "vanilla"
if attn_type == "vanilla":
assert attn_kwargs is None
return partialclass(
VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return partialclass(
MemoryEfficientVideoBlock,
in_channels,
alpha=alpha,
merge_strategy=merge_strategy,
)
else:
return NotImplementedError()
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)
def _make_attn(self) -> Callable:
if self.time_mode not in ["conv-only", "only-last-conv"]:
return partialclass(
make_time_attn,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_attn()
def _make_conv(self) -> Callable:
if self.time_mode != "attn-only":
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
else:
return Conv2DWrapper
def _make_resblock(self) -> Callable:
if self.time_mode not in ["attn-only", "only-last-conv"]:
return partialclass(
VideoResBlock,
video_kernel_size=self.video_kernel_size,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_resblock()

View File

@@ -1,7 +0,0 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper

View File

@@ -1,62 +1,74 @@
from typing import Dict, Union
import torch
import torch.nn as nn
from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling
from .discretizer import Discretization
class Denoiser(nn.Module):
def __init__(self, weighting_config, scaling_config):
def __init__(self, scaling_config: Dict):
super().__init__()
self.weighting = instantiate_from_config(weighting_config)
self.scaling = instantiate_from_config(scaling_config)
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
def possibly_quantize_sigma(self, sigma):
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma
def possibly_quantize_c_noise(self, c_noise):
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
return c_noise
def w(self, sigma):
return self.weighting(sigma)
def __call__(self, network, input, sigma, cond):
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
**additional_model_inputs,
) -> torch.Tensor:
sigma = self.possibly_quantize_sigma(sigma)
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
return network(input * c_in, c_noise, cond) * c_out + input * c_skip
return (
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+ input * c_skip
)
class DiscreteDenoiser(Denoiser):
def __init__(
self,
weighting_config,
scaling_config,
num_idx,
discretization_config,
do_append_zero=False,
quantize_c_noise=True,
flip=True,
scaling_config: Dict,
num_idx: int,
discretization_config: Dict,
do_append_zero: bool = False,
quantize_c_noise: bool = True,
flip: bool = True,
):
super().__init__(weighting_config, scaling_config)
sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
super().__init__(scaling_config)
self.discretization: Discretization = instantiate_from_config(
discretization_config
)
sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
self.register_buffer("sigmas", sigmas)
self.quantize_c_noise = quantize_c_noise
self.num_idx = num_idx
def sigma_to_idx(self, sigma):
def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
dists = sigma - self.sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def idx_to_sigma(self, idx):
def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
return self.sigmas[idx]
def possibly_quantize_sigma(self, sigma):
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
return self.idx_to_sigma(self.sigma_to_idx(sigma))
def possibly_quantize_c_noise(self, c_noise):
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
if self.quantize_c_noise:
return self.sigma_to_idx(c_noise)
else:

View File

@@ -1,11 +1,24 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
class DenoiserScaling(ABC):
@abstractmethod
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
pass
class EDMScaling:
def __init__(self, sigma_data=0.5):
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(self, sigma):
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
@@ -14,7 +27,9 @@ class EDMScaling:
class EpsScaling:
def __call__(self, sigma):
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = torch.ones_like(sigma, device=sigma.device)
c_out = -sigma
c_in = 1 / (sigma**2 + 1.0) ** 0.5
@@ -23,9 +38,22 @@ class EpsScaling:
class VScaling:
def __call__(self, sigma):
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScalingWithEDMcNoise(DenoiserScaling):
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise

View File

@@ -1,31 +1,33 @@
from functools import partial
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from ...util import default, instantiate_from_config
from ...util import append_dims, default
logpy = logging.getLogger(__name__)
class VanillaCFG:
"""
implements parallelized CFG
"""
class Guider(ABC):
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
def __init__(self, scale, dyn_thresh_config=None):
scale_schedule = lambda scale, sigma: scale # independent of step
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
},
)
)
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
pass
def __call__(self, x, sigma):
class VanillaCFG(Guider):
def __init__(self, scale: float):
self.scale = scale
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
scale_value = self.scale_schedule(sigma)
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
x_pred = x_u + self.scale * (x_c - x_u)
return x_pred
def prepare_inputs(self, x, s, c, uc):
@@ -40,14 +42,58 @@ class VanillaCFG:
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class IdentityGuider:
def __call__(self, x, sigma):
class IdentityGuider(Guider):
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
return x
def prepare_inputs(self, x, s, c, uc):
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
class LinearPredictionGuider(Guider):
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
self.min_scale = min_scale
self.max_scale = max_scale
self.num_frames = num_frames
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
additional_cond_keys = default(additional_cond_keys, [])
if isinstance(additional_cond_keys, str):
additional_cond_keys = [additional_cond_keys]
self.additional_cond_keys = additional_cond_keys
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
def prepare_inputs(
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
c_out = dict()
for k in c:
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out

View File

@@ -1,31 +1,34 @@
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from omegaconf import ListConfig
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.encoders.modules import GeneralConditioner
from ...util import append_dims, instantiate_from_config
from .denoiser import Denoiser
class StandardDiffusionLoss(nn.Module):
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
sigma_sampler_config: dict,
loss_weighting_config: dict,
loss_type: str = "l2",
offset_noise_level: float = 0.0,
batch2model_keys: Optional[Union[str, List[str]]] = None,
):
super().__init__()
assert type in ["l2", "l1", "lpips"]
assert loss_type in ["l2", "l1", "lpips"]
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
self.loss_weighting = instantiate_from_config(loss_weighting_config)
self.type = type
self.loss_type = loss_type
self.offset_noise_level = offset_noise_level
if type == "lpips":
if loss_type == "lpips":
self.lpips = LPIPS().eval()
if not batch2model_keys:
@@ -36,34 +39,67 @@ class StandardDiffusionLoss(nn.Module):
self.batch2model_keys = set(batch2model_keys)
def __call__(self, network, denoiser, conditioner, input, batch):
def get_noised_input(
self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
) -> torch.Tensor:
noised_input = input + noise * sigmas_bc
return noised_input
def forward(
self,
network: nn.Module,
denoiser: Denoiser,
conditioner: GeneralConditioner,
input: torch.Tensor,
batch: Dict,
) -> torch.Tensor:
cond = conditioner(batch)
return self._forward(network, denoiser, cond, input, batch)
def _forward(
self,
network: nn.Module,
denoiser: Denoiser,
cond: Dict,
input: torch.Tensor,
batch: Dict,
) -> Tuple[torch.Tensor, Dict]:
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
sigmas = self.sigma_sampler(input.shape[0]).to(input)
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
noise = torch.randn_like(input)
if self.offset_noise_level > 0.0:
noise = noise + self.offset_noise_level * append_dims(
torch.randn(input.shape[0], device=input.device), input.ndim
offset_shape = (
(input.shape[0], 1, input.shape[2])
if self.n_frames is not None
else (input.shape[0], input.shape[1])
)
noised_input = input + noise * append_dims(sigmas, input.ndim)
noise = noise + self.offset_noise_level * append_dims(
torch.randn(offset_shape, device=input.device),
input.ndim,
)
sigmas_bc = append_dims(sigmas, input.ndim)
noised_input = self.get_noised_input(sigmas_bc, noise, input)
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
w = append_dims(denoiser.w(sigmas), input.ndim)
w = append_dims(self.loss_weighting(sigmas), input.ndim)
return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):
if self.type == "l2":
if self.loss_type == "l2":
return torch.mean(
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
elif self.type == "l1":
elif self.loss_type == "l1":
return torch.mean(
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
elif self.type == "lpips":
elif self.loss_type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
else:
raise NotImplementedError(f"Unknown loss type {self.loss_type}")

View File

@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
import torch
class DiffusionLossWeighting(ABC):
@abstractmethod
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
pass
class UnitWeighting(DiffusionLossWeighting):
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return torch.ones_like(sigma, device=sigma.device)
class EDMWeighting(DiffusionLossWeighting):
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
class VWeighting(EDMWeighting):
def __init__(self):
super().__init__(sigma_data=1.0)
class EpsWeighting(DiffusionLossWeighting):
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma**-2.0

View File

@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import logging
import math
from typing import Any, Callable, Optional
@@ -8,6 +9,8 @@ import torch.nn as nn
from einops import rearrange
from packaging import version
logpy = logging.getLogger(__name__)
try:
import xformers
import xformers.ops
@@ -15,7 +18,7 @@ try:
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
logpy.warning("no module 'xformers'. Processing without...")
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
logpy.info(
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
)
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
@@ -633,7 +638,7 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
logpy.info(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)

File diff suppressed because it is too large Load Diff

View File

@@ -9,13 +9,10 @@ import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm
from ...modules.diffusionmodules.sampling_utils import (
get_ancestral_step,
linear_multistep_coeff,
to_d,
to_neg_log_sigma,
to_sigma,
)
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
linear_multistep_coeff,
to_d, to_neg_log_sigma,
to_sigma)
from ...util import append_dims, default, instantiate_from_config
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

View File

@@ -4,11 +4,6 @@ from scipy import integrate
from ...util import append_dims
class NoDynamicThresholding:
def __call__(self, uncond, cond, scale):
return uncond + scale * (cond - uncond)
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
if order - 1 > i:
raise ValueError(f"Order {order} too high for step {i}")

View File

@@ -1,5 +1,5 @@
"""
adopted from
partially adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -10,10 +10,11 @@ thanks!
"""
import math
from typing import Optional
import torch
import torch.nn as nn
from einops import repeat
from einops import rearrange, repeat
def make_beta_schedule(
@@ -306,3 +307,63 @@ def avg_pool_nd(dims, *args, **kwargs):
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x

View File

@@ -0,0 +1,493 @@
from functools import partial
from typing import List, Optional, Union
from einops import rearrange
from ...modules.diffusionmodules.openaimodel import *
from ...modules.video_attention import SpatialVideoTransformer
from ...util import default
from .util import AlphaBlender
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size: Union[int, List[int]] = 3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels: Optional[int] = None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator: Optional[th.Tensor] = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class VideoUNet(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_res_blocks: int,
attention_resolutions: int,
dropout: float = 0.0,
channel_mult: List[int] = (1, 2, 4, 8),
conv_resample: bool = True,
dims: int = 2,
num_classes: Optional[int] = None,
use_checkpoint: bool = False,
num_heads: int = -1,
num_head_channels: int = -1,
num_heads_upsample: int = -1,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
transformer_depth: Union[List[int], int] = 1,
transformer_depth_middle: Optional[int] = None,
context_dim: Optional[int] = None,
time_downup: bool = False,
time_context_dim: Optional[int] = None,
extra_ff_mix_layer: bool = False,
use_spatial_context: bool = False,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
spatial_transformer_attn_type: str = "softmax",
video_kernel_size: Union[int, List[int]] = 3,
use_linear_in_transformer: bool = False,
adm_in_channels: Optional[int] = None,
disable_temporal_crossattention: bool = False,
max_ddpm_temb_period: int = 10000,
):
super().__init__()
assert context_dim is not None
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1
if num_head_channels == -1:
assert num_heads != -1
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
transformer_depth_middle = default(
transformer_depth_middle, transformer_depth[-1]
)
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disabled_sa=False,
):
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
disable_self_attn=disabled_sa,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_ch,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
):
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
)
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
ds *= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_down=time_downup,
)
)
)
ch = out_ch
input_block_chans.append(ch)
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
self.middle_block = TimestepEmbedSequential(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
out_ch=None,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
out_ch=None,
time_embed_dim=time_embed_dim,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
if level and i == num_res_blocks:
out_ch = ch
ds //= 2
layers.append(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_up=time_downup,
)
)
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def forward(
self,
x: th.Tensor,
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
for module in self.input_blocks:
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
hs.append(h)
h = self.middle_block(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
h = h.type(x.dtype)
return self.out(h)

View File

@@ -1,3 +1,4 @@
import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
@@ -10,27 +11,17 @@ import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
ByT5Tokenizer,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
)
from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
T5EncoderModel, T5Tokenizer)
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ...modules.diffusionmodules.util import (extract_into_tensor,
make_beta_schedule)
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (
autocast,
count_params,
default,
disabled_train,
expand_dims_like,
instantiate_from_config,
)
from ...util import (append_dims, autocast, count_params, default,
disabled_train, expand_dims_like, instantiate_from_config)
class AbstractEmbModel(nn.Module):
@@ -173,7 +164,11 @@ class GeneralConditioner(nn.Module):
return output
def get_unconditional_conditioning(
self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
self,
batch_c: Dict,
batch_uc: Optional[Dict] = None,
force_uc_zero_embeddings: Optional[List[str]] = None,
force_cond_zero_embeddings: Optional[List[str]] = None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
@@ -181,7 +176,7 @@ class GeneralConditioner(nn.Module):
for embedder in self.embedders:
ucg_rates.append(embedder.ucg_rate)
embedder.ucg_rate = 0.0
c = self(batch_c)
c = self(batch_c, force_cond_zero_embeddings)
uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
for embedder, rate in zip(self.embedders, ucg_rates):
@@ -201,12 +196,6 @@ class InceptionV3(nn.Module):
self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
def forward(self, inp):
# inp = kornia.geometry.resize(inp, (299, 299),
# interpolation='bicubic',
# align_corners=False,
# antialias=True)
# inp = inp.clamp(min=-1, max=1)
outp = self.model(inp)
if len(outp) == 1:
@@ -277,7 +266,6 @@ class FrozenT5Embedder(AbstractEmbModel):
for param in self.parameters():
param.requires_grad = False
# @autocast
def forward(self, text):
batch_encoding = self.tokenizer(
text,
@@ -597,11 +585,12 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
repeat_to_max_len=False,
num_image_crops=0,
output_tokens=False,
init_device=None,
):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
device=torch.device(default(init_device, "cpu")),
pretrained=version,
)
del model.transformer
@@ -914,7 +903,6 @@ class LowScaleEncoder(nn.Module):
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
@@ -958,3 +946,101 @@ class GaussianEncoder(Encoder, AbstractEmbModel):
if self.flatten_output:
z = rearrange(z, "b c h w -> b (h w ) c")
return log, z
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
def __init__(
self,
n_cond_frames: int,
n_copies: int,
encoder_config: dict,
sigma_sampler_config: Optional[dict] = None,
sigma_cond_config: Optional[dict] = None,
is_ae: bool = False,
scale_factor: float = 1.0,
disable_encoder_autocast: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
):
super().__init__()
self.n_cond_frames = n_cond_frames
self.n_copies = n_copies
self.encoder = instantiate_from_config(encoder_config)
self.sigma_sampler = (
instantiate_from_config(sigma_sampler_config)
if sigma_sampler_config is not None
else None
)
self.sigma_cond = (
instantiate_from_config(sigma_cond_config)
if sigma_cond_config is not None
else None
)
self.is_ae = is_ae
self.scale_factor = scale_factor
self.disable_encoder_autocast = disable_encoder_autocast
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def forward(
self, vid: torch.Tensor
) -> Union[
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, dict],
Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
]:
if self.sigma_sampler is not None:
b = vid.shape[0] // self.n_cond_frames
sigmas = self.sigma_sampler(b).to(vid.device)
if self.sigma_cond is not None:
sigma_cond = self.sigma_cond(sigmas)
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
noise = torch.randn_like(vid)
vid = vid + noise * append_dims(sigmas, vid.ndim)
with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
n_samples = (
self.en_and_decode_n_samples_a_time
if self.en_and_decode_n_samples_a_time is not None
else vid.shape[0]
)
n_rounds = math.ceil(vid.shape[0] / n_samples)
all_out = []
for n in range(n_rounds):
if self.is_ae:
out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
else:
out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
vid = torch.cat(all_out, dim=0)
vid *= self.scale_factor
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
return return_val
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
def __init__(
self,
open_clip_embedding_config: Dict,
n_cond_frames: int,
n_copies: int,
):
super().__init__()
self.n_cond_frames = n_cond_frames
self.n_copies = n_copies
self.open_clip = instantiate_from_config(open_clip_embedding_config)
def forward(self, vid):
vid = self.open_clip(vid)
vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
return vid

View File

@@ -0,0 +1,301 @@
import torch
from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
class TimeMixSequential(nn.Sequential):
def forward(self, x, context=None, timesteps=None):
for layer in self:
x = layer(x, context, timesteps)
return x
class VideoTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention,
"softmax-xformers": MemoryEfficientCrossAttention,
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
timesteps=None,
ff_in=False,
inner_dim=None,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
switch_temporal_ca_to_sa=False,
):
super().__init__()
attn_cls = self.ATTENTION_MODES[attn_mode]
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
assert int(n_heads * d_head) == inner_dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
)
self.timesteps = timesteps
self.disable_self_attn = disable_self_attn
if self.disable_self_attn:
self.attn1 = attn_cls(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim,
dropout=dropout,
) # is a cross-attention
else:
self.attn1 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
self.norm2 = nn.LayerNorm(inner_dim)
if switch_temporal_ca_to_sa:
self.attn2 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
else:
self.attn2 = attn_cls(
query_dim=inner_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(inner_dim)
self.norm3 = nn.LayerNorm(inner_dim)
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
) -> torch.Tensor:
if self.checkpoint:
return checkpoint(self._forward, x, context, timesteps)
else:
return self._forward(x, context, timesteps=timesteps)
def _forward(self, x, context=None, timesteps=None):
assert self.timesteps or timesteps
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
timesteps = self.timesteps or timesteps
B, S, C = x.shape
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
if self.disable_self_attn:
x = self.attn1(self.norm1(x), context=context) + x
else:
x = self.attn1(self.norm1(x)) + x
if self.attn2 is not None:
if self.switch_temporal_ca_to_sa:
x = self.attn2(self.norm2(x)) + x
else:
x = self.attn2(self.norm2(x), context=context) + x
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = rearrange(
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
return x
def get_last_layer(self):
return self.ff.net[-1].weight
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
attn_type=attn_mode,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
VideoTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
x = block(
x,
context=spatial_context,
)
x_mix = x
x_mix = x_mix + emb
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator,
)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out

View File

@@ -246,3 +246,30 @@ def get_configs_path() -> str:
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
"""
Will return the result of a recursive get attribute call.
E.g.:
a.b.c
= getattr(getattr(a, "b"), "c")
= get_nested_attribute(a, "b.c")
If any part of the attribute call is an integer x with current obj a, will
try to call a[x] instead of a.x first.
"""
attributes = attribute_path.split(".")
if depth is not None and depth > 0:
attributes = attributes[:depth]
assert len(attributes) > 0, "At least one attribute should be selected"
current_attribute = obj
current_key = None
for level, attribute in enumerate(attributes):
current_key = ".".join(attributes[: level + 1])
try:
id_ = int(attribute)
current_attribute = current_attribute[id_]
except ValueError:
current_attribute = getattr(current_attribute, attribute)
return (current_attribute, current_key) if return_key else current_attribute