Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

README.md

@@ -4,26 +4,48 @@
## News
**November 21, 2023**
- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
  - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14 frames at resolution 576x1024 given a context frame of the same size. We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
  - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD`, but finetuned for 25-frame generation.
- We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models; a minimal sketch of building the model from its config follows below.
- Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).

![tile](assets/tile.gif)
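The released checkpoints plug into the same config-driven machinery as the image models. The following is a minimal sketch of building the SVD `DiffusionEngine` from `configs/inference/svd.yaml`; the checkpoint filename and the state-dict loading step are assumptions, and `scripts/sampling/simple_video_sample.py` remains the reference entry point.

```python
# Minimal sketch: build the SVD model from its inference config.
# Assumes the repo is installed and a checkpoint has been placed in checkpoints/
# (the filename below is illustrative).
import torch
from omegaconf import OmegaConf
from safetensors.torch import load_file

from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/inference/svd.yaml")
model = instantiate_from_config(config.model)          # sgm.models.diffusion.DiffusionEngine
state_dict = load_file("checkpoints/svd.safetensors")  # hypothetical filename
model.load_state_dict(state_dict, strict=False)
model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")
```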
**July 26, 2023**
- We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
  - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
  - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
![sample2](assets/001_with_eval.png)
**July 4, 2023**
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
**June 22, 2023**
- We are releasing two new diffusion models for research purposes:
  - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding, whereas the refiner model only uses the OpenCLIP model.
  - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high-quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
If you would like to access these models for your research, please apply using one of the following links: [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply via either of the two links, and if you are granted access, you can access both.
Please log in to your Hugging Face account with your organization email to request access.
**We plan to do a full release soon (July).**
@@ -32,21 +54,32 @@ Please log in to your Hugging Face Account with your organization email to reque
### General Philosophy
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
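The pattern is simple: each config node names a `target` class path plus optional `params`, and `instantiate_from_config()` imports the class and constructs it. A rough sketch of the idea, assuming the helper behaves like the one in `sgm/util.py`:

```python
# Rough illustration of the config-driven pattern (assumed to mirror sgm/util.py).
import importlib


def get_obj_from_str(string: str):
    """Resolve a dotted path such as 'sgm.models.diffusion.DiffusionEngine' to a class."""
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config: dict):
    """Build the object described by {'target': ..., 'params': {...}}."""
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# Toy example of the pattern (any importable class works the same way):
layer = instantiate_from_config(
    {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
)
```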
### Changelog from the old `ldm` codebase
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable change is probably the option to train continuous-time models; a schematic preconditioning sketch follows this list):
  * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
  * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
- Autoencoding models have also been cleaned up.
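Concretely, in that framework a denoiser wraps the raw network with sigma-dependent scalings. The sketch below uses the EDM preconditioning from the linked paper; the function layout and `sigma_data` default are illustrative and not the repo's `Denoiser`/`denoiser_scaling` code.

```python
# Schematic EDM-style preconditioning (Karras et al., arXiv:2206.00364).
# Illustration only; the repo's scalings live in sgm/modules/diffusionmodules/denoiser_scaling.py.
import torch


def edm_scalings(sigma: torch.Tensor, sigma_data: float = 1.0):
    """Return (c_skip, c_out, c_in, c_noise) for the EDM parameterization."""
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    c_in = 1.0 / (sigma**2 + sigma_data**2) ** 0.5
    c_noise = 0.25 * sigma.log()
    return c_skip, c_out, c_in, c_noise


def denoise(network, noised_input: torch.Tensor, sigma: torch.Tensor, **cond):
    """D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    c_skip, c_out, c_in, c_noise = edm_scalings(sigma)
    return c_skip * noised_input + c_out * network(c_in * noised_input, c_noise, **cond)
```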
## Installation:
<a name="installation"></a>
#### 1. Clone the repo
@@ -60,21 +93,10 @@ cd generative-models
This is assuming you have navigated to the `generative-models` root after cloning it.
-**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
+**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
-**PyTorch 1.13**
-```shell
-# install required packages from pypi
-python3 -m venv .pt13
-source .pt13/bin/activate
-pip3 install -r requirements/pt13.txt
-```
**PyTorch 2.0**
```shell
# install required packages from pypi
python3 -m venv .pt2
@@ -82,7 +104,6 @@ source .pt2/bin/activate
pip3 install -r requirements/pt2.txt
```
#### 3. Install `sgm`
```shell
@@ -114,8 +135,10 @@ depending on your use case and PyTorch version, manually.
## Inference
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that; a generic checksum sketch follows below).
The following models are currently supported:
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
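To check a download against the published full-file hash, a plain SHA-256 over the checkpoint suffices. The snippet below is a generic sketch, not the Model Spec script (the tensor-only hash does require that tooling), and the filename is illustrative.

```python
# Generic sketch: compute the full-file SHA-256 of a downloaded checkpoint.
import hashlib
from pathlib import Path


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    digest = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


print(sha256_of_file("checkpoints/sd_xl_base_1.0.safetensors"))  # filename is illustrative
```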
@@ -136,19 +159,20 @@ The following models are currently supported:
**Weights for SDXL**:
**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links: [SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply via either of the two links, and if you are granted access, you can access both.
Please log in to your Hugging Face account with your organization email to request access.
After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using
@@ -166,6 +190,7 @@ not the same as in previous Stable Diffusion 1.x/2.x versions.
To run the script you need to either have a working installation as above or try an _experimental_ import using only a minimal amount of packages:
```bash
python -m venv .detect
source .detect/bin/activate
@@ -177,6 +202,7 @@ pip install --no-deps invisible-watermark
To run the script you need to have a working installation as above. The script is then usable in the following ways (don't forget to activate your virtual environment beforehand, e.g. `source .pt1/bin/activate`):
```bash
# test a single file
python scripts/demo/detect.py <your filename here>
@@ -203,11 +229,21 @@ run
python main.py --base configs/example_training/toy/mnist_cond.yaml
```
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to be stored in tar files in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
### Building New Diffusion Models
@@ -216,7 +252,8 @@ python main.py --base configs/example_training/toy/mnist_cond.yaml
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the classifier-free guidance dropout rate used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input.
We currently support two- to four-dimensional conditionings, and conditionings of different embedders are concatenated appropriately.
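As a rough sketch of that interface, a toy class-conditioning embedder might look as follows; the real base class is `AbstractEmbModel` in `sgm/modules/encoders/modules.py`, and the attribute wiring shown here is an illustrative assumption, not the repo's implementation.

```python
# Toy embedder sketch; in the repo, embedders subclass
# sgm.modules.encoders.modules.AbstractEmbModel.
import torch
import torch.nn as nn


class ToyClassEmbedder(nn.Module):
    def __init__(self, n_classes: int = 1000, embed_dim: int = 256):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        # Attributes the conditioner expects, per the description above:
        self.is_trainable = True   # default False
        self.ucg_rate = 0.2        # classifier-free guidance dropout rate, default 0
        self.input_key = "cls"     # the embedder receives batch["cls"]

    def forward(self, cls: torch.Tensor) -> torch.Tensor:
        # Returns a vector conditioning of shape (batch, embed_dim).
        return self.embedding(cls)
```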
@@ -229,7 +266,8 @@ enough as we plan to experiment with transformer-based diffusion backbones.
#### Loss
The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
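For orientation, a `loss_fn_config` in the new format looks roughly like the following, mirroring the toy training configs touched by this commit; exact parameter sets vary per config, so treat it as a template rather than a complete config.

```python
# Template of a loss_fn_config, mirroring the toy configs in this commit.
loss_fn_config = {
    "target": "sgm.modules.diffusionmodules.loss.StandardDiffusionLoss",
    "params": {
        "loss_weighting_config": {
            "target": "sgm.modules.diffusionmodules.loss_weighting.EDMWeighting",
            "params": {"sigma_data": 1.0},
        },
        "sigma_sampler_config": {
            "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling",
        },
    },
}
```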
#### Sampler config
@@ -239,8 +277,9 @@ guidance.
### Dataset Handling
For large-scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirements and automatically included when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of data keys/values, e.g. as in the minimal sketch below.
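A minimal sketch of such a map-style dataset (the key names `jpg` and `cls` are illustrative):

```python
# Minimal map-style dataset returning a dict of data keys/values.
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    def __init__(self, num_samples: int = 1000, resolution: int = 32):
        self.num_samples = num_samples
        self.resolution = resolution

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> dict:
        return {
            "jpg": torch.randn(3, self.resolution, self.resolution),  # image tensor
            "cls": idx % 10,                                          # class label
        }
```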

BIN  assets/test_image.png  (new binary file, 482 KiB)

BIN  assets/tile.gif  (new binary file, 18 MiB)

@@ -29,25 +29,14 @@ model:
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4]
num_res_blocks: 4
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
-params:
+params: ${model.params.encoder_config.params}
-attn_type: none
-double_z: False
-z_channels: 4
-resolution: 256
-in_channels: 3
-out_ch: 3
-ch: 128
-ch_mult: [ 1, 2, 4 ]
-num_res_blocks: 4
-attn_resolutions: [ ]
-dropout: 0.0
data:
target: sgm.data.dataset.StableDataModuleFromConfig
@@ -55,18 +44,18 @@ data:
train:
datapipeline:
urls:
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:


@@ -0,0 +1,105 @@
model:
base_learning_rate: 4.5e-6
target: sgm.models.autoencoder.AutoencodingEngine
params:
input_key: jpg
monitor: val/loss/rec
disc_start_iter: 0
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla-xformers
double_z: true
z_channels: 8
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params: ${model.params.encoder_config.params}
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
loss_config:
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
params:
perceptual_weight: 0.25
disc_start: 20001
disc_weight: 0.5
learn_logvar: True
regularization_weights:
kl_loss: 1.0
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
- DATA-PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
params:
h_key: height
w_key: width
loader:
batch_size: 8
num_workers: 4
lightning:
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 50000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
devices: 0,
limit_val_batches: 50
benchmark: True
accumulate_grad_batches: 1
val_check_interval: 10000


@@ -21,8 +21,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,7 +30,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
-use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 256
@@ -42,7 +39,6 @@ model:
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1024
-use_spatial_transformer: true
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
@@ -51,32 +47,31 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
-add_sequence_dim: True # will be used through crossattn then
+add_sequence_dim: True
embed_dim: 1024
n_classes: 1000
-# vector cond
- is_trainable: False
ucg_rate: 0.2
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -99,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -127,18 +124,18 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:


@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 3
out_channels: 3
model_channels: 32
@@ -46,6 +41,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling


@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -32,6 +27,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling


@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,13 +13,12 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
@@ -33,7 +28,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +41,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling


@@ -7,8 +7,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
discretization_config:
@@ -17,13 +15,12 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
@@ -33,7 +30,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +43,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:


@@ -5,10 +5,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -17,7 +13,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -25,7 +20,7 @@ model:
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
@@ -33,7 +28,7 @@ model:
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
@@ -46,6 +41,11 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_type: l1
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
@@ -62,11 +62,6 @@ model:
params:
scale: 3.0
-loss_config:
-target: sgm.modules.diffusionmodules.StandardDiffusionLoss
-params:
-type: l1
data:
target: sgm.data.mnist.MNISTLoader
params:


@@ -7,10 +7,6 @@ model:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
@@ -19,7 +15,6 @@ model:
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
-use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
@@ -48,6 +43,10 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling


@@ -10,19 +10,17 @@ model:
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [1.]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,18 +30,16 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
-use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [1, 2, 4]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
-use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
@@ -52,7 +48,6 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
@@ -60,23 +55,23 @@ model:
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
-# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -99,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -127,18 +124,18 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:


@@ -10,19 +10,17 @@ model:
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [1.]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -32,18 +30,16 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
-use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [1, 2, 4]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
-use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
@@ -52,30 +48,30 @@ model:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
+legacy_ucg_value: ""
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
-# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
ckpt_path: CKPT_PATH
embed_dim: 4
@@ -88,9 +84,9 @@ model:
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
@@ -98,6 +94,8 @@ model:
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
@@ -126,19 +124,19 @@ data:
datapipeline:
urls:
# USER: adapt this path the root of your custom dataset
- DATA_PATH
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- pil
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: jpg # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:


@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -20,7 +18,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
-use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
@@ -28,17 +25,14 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
-use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
-legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
@@ -47,7 +41,7 @@ model:
layer: penultimate
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss


@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
discretization_config:
@@ -20,7 +18,6 @@ model:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
-use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
@@ -28,17 +25,14 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
-use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
-legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
@@ -47,7 +41,7 @@ model:
layer: penultimate
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss


@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -29,25 +27,22 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
-use_spatial_transformer: True
use_linear_in_transformer: True
-transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+transformer_depth: [1, 2, 10]
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
-legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
-# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -58,27 +53,27 @@ model:
layer: penultimate
always_return_pooled: True
legacy: False
-# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss


@@ -9,8 +9,6 @@ model:
params:
num_idx: 1000
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
@@ -29,18 +27,15 @@ model:
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
-use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
-context_dim: [1280, 1280, 1280, 1280] # 1280
+context_dim: [1280, 1280, 1280, 1280]
spatial_transformer_attn_type: softmax-xformers
-legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
-# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -51,27 +46,27 @@ model:
freeze: True
layer: penultimate
always_return_pooled: True
-# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
-outdim: 256 # multiplied by one
+outdim: 256
first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss

configs/inference/svd.yaml (new file)

@@ -0,0 +1,131 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]


@@ -0,0 +1,114 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
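
The config above and `configs/inference/svd.yaml` differ essentially only in the first stage: this one keeps the frame-wise `AutoencoderKL` decoder, while `svd.yaml` swaps in the temporally-aware `VideoDecoder`. A quick way to confirm that from the files themselves (paths as used in this commit):

```python
from omegaconf import OmegaConf

svd = OmegaConf.load("configs/inference/svd.yaml")
svd_image_decoder = OmegaConf.load("configs/inference/svd_image_decoder.yaml")

# AutoencodingEngine with a temporal VideoDecoder vs. plain AutoencoderKL
print(svd.model.params.first_stage_config.target)
print(svd_image_decoder.model.params.first_stage_config.target)
```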

LICENSE (Stable Video Diffusion non-commercial community license, new file)

@@ -0,0 +1,31 @@
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
Dated: November 21, 2023
“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Models output. For clarity, Derivative Works do not include the output of any Model.
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
"Stability AI" or "we" means Stability AI Ltd.
"Software" means, collectively, Stability AIs proprietary StableCode made available under this Agreement.
“Software Products” means Software and Documentation.
By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
License Rights and Redistribution.
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AIs intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
3. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
Subject to Stability AIs ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.

scripts/demo/discretization.py (new file)

@@ -0,0 +1,59 @@
import torch
from sgm.modules.diffusionmodules.discretizer import Discretization
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization: Discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(
self, discretization: Discretization, strength: float = 0.0, original_steps=None
):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
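
A short usage sketch for the wrappers above (the discretization class and its parameters are only an example; any `Discretization` from `sgm.modules.diffusionmodules.discretizer` should behave the same way): with `strength=0.5`, roughly the lower half of the noise schedule survives the pruning, which is what the img2img path relies on.

```python
from scripts.demo.discretization import Img2ImgDiscretizationWrapper
from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

# Example parameters; the demo sets these from the UI / model options.
discretization = EDMDiscretization(sigma_min=0.002, sigma_max=700.0, rho=7.0)
wrapped = Img2ImgDiscretizationWrapper(discretization, strength=0.5)

sigmas = wrapped(25)  # request 25 sigmas, keep roughly the smaller half
print(len(sigmas), float(sigmas.max()), float(sigmas.min()))
```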

scripts/demo/sampling.py

@@ -253,7 +253,10 @@ if __name__ == "__main__":
st.title("Stable Diffusion") st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version] version_dict = VERSION2SPECS[version]
if st.checkbox("Load Model"):
mode = st.radio("Mode", ("txt2img", "img2img"), 0) mode = st.radio("Mode", ("txt2img", "img2img"), 0)
else:
mode = "skip"
st.write("__________________________") st.write("__________________________")
set_lowvram_mode(st.checkbox("Low vram mode", True)) set_lowvram_mode(st.checkbox("Low vram mode", True))
@@ -269,6 +272,7 @@ if __name__ == "__main__":
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version)) save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
if mode != "skip":
state = init_st(version_dict, load_filter=True) state = init_st(version_dict, load_filter=True)
if state["msg"]: if state["msg"]:
st.info(state["msg"]) st.info(state["msg"])
@@ -333,6 +337,8 @@ if __name__ == "__main__":
filter=state.get("filter"), filter=state.get("filter"),
stage2strength=stage2strength, stage2strength=stage2strength,
) )
elif mode == "skip":
out = None
else: else:
raise ValueError(f"unknown mode {mode}") raise ValueError(f"unknown mode {mode}")
if isinstance(out, (tuple, list)): if isinstance(out, (tuple, list)):

scripts/demo/streamlit_helpers.py

@@ -1,10 +1,15 @@
import copy
import math import math
import os import os
from typing import List, Union from glob import glob
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np import numpy as np
import streamlit as st import streamlit as st
import torch import torch
import torch.nn as nn
import torchvision.transforms as TT
from einops import rearrange, repeat from einops import rearrange, repeat
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf from omegaconf import ListConfig, OmegaConf
@@ -12,63 +17,22 @@ from PIL import Image
from safetensors.torch import load_file as load_safetensors from safetensors.torch import load_file as load_safetensors
from torch import autocast from torch import autocast
from torchvision import transforms from torchvision import transforms
from torchvision.utils import make_grid from torchvision.utils import make_grid, save_image
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
from sgm.modules.diffusionmodules.sampling import ( Txt2NoisyDiscretizationWrapper)
DPMPP2MSampler, from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
VanillaCFG)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler, DPMPP2SAncestralSampler,
EulerAncestralSampler, EulerAncestralSampler,
EulerEDMSampler, EulerEDMSampler,
HeunEDMSampler, HeunEDMSampler,
LinearMultistepSampler, LinearMultistepSampler)
) from sgm.util import append_dims, default, instantiate_from_config
from sgm.util import append_dims, instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was choosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
@st.cache_resource()
@@ -164,11 +128,12 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
                prompt = "A professional photograph of an astronaut riding a pig"
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")
                negative_prompt = ""
            prompt = st.text_input("Prompt", prompt)
            negative_prompt = st.text_input("Negative prompt", negative_prompt)
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = negative_prompt
@@ -203,13 +168,35 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
value_dict["target_width"] = init_dict["target_width"] value_dict["target_width"] = init_dict["target_width"]
value_dict["target_height"] = init_dict["target_height"] value_dict["target_height"] = init_dict["target_height"]
if key in ["fps_id", "fps"]:
fps = st.number_input("fps", value=6, min_value=1)
value_dict["fps"] = fps
value_dict["fps_id"] = fps - 1
if key == "motion_bucket_id":
mb_id = st.number_input("motion bucket id", 0, 511, value=127)
value_dict["motion_bucket_id"] = mb_id
if key == "pool_image":
st.text("Image for pool conditioning")
image = load_img(
key="pool_image_input",
size=224,
center_crop=True,
)
if image is None:
st.info("Need an image here")
image = torch.zeros(1, 3, 224, 224)
value_dict["pool_image"] = image
    return value_dict


def perform_save_locally(save_path, samples):
    os.makedirs(os.path.join(save_path), exist_ok=True)
    base_count = len(os.listdir(os.path.join(save_path)))
    samples = embed_watemark(samples)
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
@@ -228,95 +215,99 @@ def init_save_locally(_dir, init_value: bool = False):
    return save_locally, save_path


class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def get_guider(key):
    guider = st.sidebar.selectbox(
        f"Discretization #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
        ],
    )
    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale = st.number_input(
            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
        )
        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )
        if thresholder == "None":
            dyn_thresh_config = {
                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
            }
        else:
            raise NotImplementedError
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
        }
    else:
        raise NotImplementedError


def get_guider(options, key):
    guider = st.sidebar.selectbox(
        f"Discretization #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
            "LinearPredictionGuider",
        ],
        options.get("guider", 0),
    )
    additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale_schedule = st.sidebar.selectbox(
            f"Scale schedule #{key}",
            ["Identity", "Oscillating"],
        )
        if scale_schedule == "Identity":
            scale = st.number_input(
                f"cfg-scale #{key}",
                value=options.get("cfg", 5.0),
                min_value=0.0,
            )
            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
                "params": {"scale": scale},
            }
        elif scale_schedule == "Oscillating":
            small_scale = st.number_input(
                f"small cfg-scale #{key}",
                value=4.0,
                min_value=0.0,
            )
            large_scale = st.number_input(
                f"large cfg-scale #{key}",
                value=16.0,
                min_value=0.0,
            )
            sigma_cutoff = st.number_input(
                f"sigma cutoff #{key}",
                value=1.0,
                min_value=0.0,
            )
            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
                "params": {
                    "small_scale": small_scale,
                    "large_scale": large_scale,
                    "sigma_cutoff": sigma_cutoff,
                },
            }
        else:
            raise NotImplementedError
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {
                "scale_schedule_config": scale_schedule_config,
                **additional_guider_kwargs,
            },
        }
    elif guider == "LinearPredictionGuider":
        max_scale = st.number_input(
            f"max-cfg-scale #{key}",
            value=options.get("cfg", 1.5),
            min_value=1.0,
        )
        min_scale = st.number_input(
            f"min guidance scale",
            value=options.get("min_cfg", 1.0),
            min_value=1.0,
            max_value=10.0,
        )
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
            "params": {
                "max_scale": max_scale,
                "min_scale": min_scale,
                "num_frames": options["num_frames"],
                **additional_guider_kwargs,
            },
        }
    else:
        raise NotImplementedError
@@ -325,18 +316,21 @@ def get_guider(key):
def init_sampling(
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
    img2img_strength: Optional[float] = None,
    specify_num_samples: bool = True,
    stage2strength: Optional[float] = None,
    options: Optional[Dict[str, int]] = None,
):
options = {} if options is None else options
    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
        )
    steps = st.sidebar.number_input(
        f"steps #{key}", value=40, min_value=1, max_value=1000
        f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
    )
    sampler = st.sidebar.selectbox(
        f"Sampler #{key}",
@@ -348,7 +342,7 @@ def init_sampling(
"DPMPP2MSampler", "DPMPP2MSampler",
"LinearMultistepSampler", "LinearMultistepSampler",
], ],
0, options.get("sampler", 0),
) )
discretization = st.sidebar.selectbox( discretization = st.sidebar.selectbox(
f"Discretization #{key}", f"Discretization #{key}",
@@ -356,14 +350,15 @@ def init_sampling(
"LegacyDDPMDiscretization", "LegacyDDPMDiscretization",
"EDMDiscretization", "EDMDiscretization",
], ],
options.get("discretization", 0),
) )
discretization_config = get_discretization(discretization, key=key) discretization_config = get_discretization(discretization, options=options, key=key)
guider_config = get_guider(key=key) guider_config = get_guider(options=options, key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key) sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0: if img2img_strength is not None:
st.warning( st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper" f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
) )
@@ -377,15 +372,19 @@ def init_sampling(
    return sampler, num_rows, num_cols


def get_discretization(discretization, key=1):
def get_discretization(discretization, options, key=1):
    if discretization == "LegacyDDPMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
        rho = st.number_input(f"rho #{key}", value=3.0)
        sigma_min = st.number_input(
            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
        )  # 0.0292
        sigma_max = st.number_input(
            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
        )  # 14.6146
        rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
@@ -474,8 +473,8 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
    return sampler


def get_interactive_image(key=None) -> Image.Image:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
def get_interactive_image() -> Image.Image:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
@@ -483,8 +482,12 @@ def get_interactive_image(key=None) -> Image.Image:
        return image


def load_img(display=True, key=None):
    image = get_interactive_image(key=key)
def load_img(
    display: bool = True,
    size: Union[None, int, Tuple[int, int]] = None,
    center_crop: bool = False,
):
    image = get_interactive_image()
    if image is None:
        return None
    if display:
@@ -492,12 +495,15 @@ def load_img(display=True, key=None):
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ]
    )
    transform = []
    if size is not None:
        transform.append(transforms.Resize(size))
    if center_crop:
        transform.append(transforms.CenterCrop(size))
    transform.append(transforms.ToTensor())
    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
    transform = transforms.Compose(transform)
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img
@@ -518,15 +524,18 @@ def do_sample(
    W,
    C,
    F,
    force_uc_zero_embeddings: List = None,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
T=None,
additional_batch_uc_fields=None,
decoding_t=None,
):
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
    batch2model_input = default(batch2model_input, [])
    additional_batch_uc_fields = default(additional_batch_uc_fields, [])
    st.text("Sampling")
@@ -535,24 +544,25 @@ def do_sample(
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
if T is not None:
num_samples = [num_samples, T]
else:
                    num_samples = [num_samples]
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
                )
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)
@@ -561,9 +571,28 @@ def do_sample(
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )
if k in ["crossattn", "concat"] and T is not None:
uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
c[k] = repeat(c[k], "b ... -> b t ...", t=T)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
                additional_model_inputs = {}
                for k in batch2model_input:
if k == "image_only_indicator":
assert T is not None
if isinstance(
sampler.guider, (VanillaCFG, LinearPredictionGuider)
):
additional_model_inputs[k] = torch.zeros(
num_samples[0] * 2, num_samples[1]
).to("cuda")
else:
additional_model_inputs[k] = torch.zeros(num_samples).to(
"cuda"
)
else:
                        additional_model_inputs[k] = batch[k]
                shape = (math.prod(num_samples), C, H // F, W // F)
@@ -581,6 +610,9 @@ def do_sample(
                unload_model(model.denoiser)
                load_model(model.first_stage_model)
model.en_and_decode_n_samples_a_time = (
decoding_t # Decode n frames at a time
)
                samples_x = model.decode_first_stage(samples_z)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                unload_model(model.first_stage_model)
@@ -588,16 +620,32 @@ def do_sample(
                if filter is not None:
                    samples = filter(samples)
if T is None:
                    grid = torch.stack([samples])
                    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                    outputs.image(grid.cpu().numpy())
else:
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
for i, vid in enumerate(as_vids):
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
st.image(
grid.cpu().numpy(),
f"Sample #{i} as image",
)
                if return_latents:
                    return samples, samples_z
                return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
def get_batch(
keys,
value_dict: dict,
N: Union[List, ListConfig],
device: str = "cuda",
T: int = None,
additional_batch_uc_fields: List[str] = [],
):
    # Hardcoded demo setups; might undergo some changes in the future
    batch = {}
@@ -605,21 +653,15 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch["txt"] = [value_dict["prompt"]] * math.prod(N)
            batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
                .repeat(math.prod(N), 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
@@ -627,30 +669,67 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]] [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
) )
.to(device) .to(device)
.repeat(*N, 1) .repeat(math.prod(N), 1)
) )
elif key == "aesthetic_score": elif key == "aesthetic_score":
batch["aesthetic_score"] = ( batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) torch.tensor([value_dict["aesthetic_score"]])
.to(device)
.repeat(math.prod(N), 1)
) )
batch_uc["aesthetic_score"] = ( batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]) torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device) .to(device)
.repeat(*N, 1) .repeat(math.prod(N), 1)
) )
elif key == "target_size_as_tuple": elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = ( batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]) torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device) .to(device)
.repeat(*N, 1) .repeat(math.prod(N), 1)
)
elif key == "fps":
batch[key] = (
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
)
elif key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(math.prod(N))
)
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
elif key in additional_batch_uc_fields and key not in batch_uc:
batch_uc[key] = copy.copy(batch[key])
    return batch, batch_uc
@@ -661,7 +740,8 @@ def do_img2img(
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=[],
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    additional_kwargs={},
    offset_noise_level: int = 0.0,
    return_latents=False,
@@ -686,6 +766,7 @@ def do_img2img(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)
                for k in c:
@@ -736,9 +817,112 @@ def do_img2img(
                if filter is not None:
                    samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())
                if return_latents:
                    return samples, samples_z
                return samples
def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
r_bound = desired_shape[1] / desired_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if aspect_r >= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r < 1.0:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r > 1:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
return factor
def get_interactive_image(key=None) -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def load_img_for_prediction(
W: int, H: int, display=True, key=None, device="cuda"
) -> torch.Tensor:
image = get_interactive_image(key=key)
if image is None:
return None
if display:
st.image(image)
w, h = image.size
image = np.array(image).transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
image = image.unsqueeze(0)
rfs = get_resizing_factor((H, W), (h, w))
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
top = (resize_size[0] - H) // 2
left = (resize_size[1] - W) // 2
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
if display:
numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
st.image(pil_image)
return image.to(device) * 2.0 - 1.0
def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
os.makedirs(save_path, exist_ok=True)
base_count = len(glob(os.path.join(save_path, "*.mp4")))
video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
video_batch = embed_watermark(video_batch)
for vid in video_batch:
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"),
fps,
(vid.shape[-1], vid.shape[-2]),
)
vid = (
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
video_path_h264 = video_path[:-4] + "_h264.mp4"
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
with open(video_path_h264, "rb") as f:
video_bytes = f.read()
st.video(video_bytes)
base_count += 1
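
To make the aspect-ratio handling in `load_img_for_prediction` above concrete, here is a small worked example with a made-up 720x1280 input and the model's 576x1024 target: `get_resizing_factor` returns 576/720 = 0.8, the area resize lands exactly on 576x1024, and the subsequent center crop is then a no-op.

```python
import math

import torch

from scripts.demo.streamlit_helpers import get_resizing_factor

H, W = 576, 1024  # model resolution
h, w = 720, 1280  # hypothetical upload

rfs = get_resizing_factor((H, W), (h, w))           # 576 / 720 = 0.8
resize_size = [math.ceil(rfs * s) for s in (h, w)]  # [576, 1024]

image = torch.rand(1, 3, h, w)
image = torch.nn.functional.interpolate(
    image, resize_size, mode="area", antialias=False
)
print(rfs, image.shape)  # 0.8 torch.Size([1, 3, 576, 1024])
```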

scripts/demo/video_sampling.py (new file)

@@ -0,0 +1,200 @@
import os
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
SAVE_PATH = "outputs/demo/vid/"
VERSION2SPECS = {
"svd": {
"T": 14,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd.yaml",
"ckpt": "checkpoints/svd.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 25,
},
},
"svd_image_decoder": {
"T": 14,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd_image_decoder.yaml",
"ckpt": "checkpoints/svd_image_decoder.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 25,
},
},
"svd_xt": {
"T": 25,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd.yaml",
"ckpt": "checkpoints/svd_xt.safetensors",
"options": {
"discretization": 1,
"cfg": 3.0,
"min_cfg": 1.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 30,
"decoding_t": 14,
},
},
"svd_xt_image_decoder": {
"T": 25,
"H": 576,
"W": 1024,
"C": 4,
"f": 8,
"config": "configs/inference/svd_image_decoder.yaml",
"ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
"options": {
"discretization": 1,
"cfg": 3.0,
"min_cfg": 1.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 2,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 30,
"decoding_t": 14,
},
},
}
if __name__ == "__main__":
st.title("Stable Video Diffusion")
version = st.selectbox(
"Model Version",
[k for k in VERSION2SPECS.keys()],
0,
)
version_dict = VERSION2SPECS[version]
if st.checkbox("Load Model"):
mode = "img2vid"
else:
mode = "skip"
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
)
T = st.sidebar.number_input(
"T", value=version_dict["T"], min_value=0, max_value=128
)
C = version_dict["C"]
F = version_dict["f"]
options = version_dict["options"]
if mode != "skip":
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
ukeys = set(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
)
value_dict = init_embedder_options(
ukeys,
{},
)
value_dict["image_only_indicator"] = 0
if mode == "img2vid":
img = load_img_for_prediction(W, H)
cond_aug = st.number_input(
"Conditioning augmentation:", value=0.02, min_value=0.0
)
value_dict["cond_frames_without_noise"] = img
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
value_dict["cond_aug"] = cond_aug
seed = st.sidebar.number_input(
"seed", value=23, min_value=0, max_value=int(1e9)
)
seed_everything(seed)
save_locally, save_path = init_save_locally(
os.path.join(SAVE_PATH, version), init_value=True
)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling(options=options)
num_samples = num_rows * num_cols
decoding_t = st.number_input(
"Decode t frames at a time (set small if you are low on VRAM)",
value=options.get("decoding_t", T),
min_value=1,
max_value=int(1e9),
)
if st.checkbox("Overwrite fps in mp4 generator", False):
saving_fps = st.number_input(
f"saving video at fps:", value=value_dict["fps"], min_value=1
)
else:
saving_fps = value_dict["fps"]
if st.button("Sample"):
out = do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
T=T,
batch2model_input=["num_video_frames", "image_only_indicator"],
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
force_cond_zero_embeddings=options.get(
"force_cond_zero_embeddings", None
),
return_latents=False,
decoding_t=decoding_t,
)
if isinstance(out, (tuple, list)):
samples, samples_z = out
else:
samples = out
samples_z = None
if save_locally:
save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
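
Collecting the conditioning pieces the app above feeds into `do_sample`, the SVD embedders expect a `value_dict` roughly like the following sketch (the numbers mirror the widget defaults, and the image tensor is a placeholder in [-1, 1] of shape (1, 3, H, W)):

```python
import torch

img = torch.zeros(1, 3, 576, 1024)  # placeholder conditioning frame
cond_aug = 0.02
fps = 6

value_dict = {
    "fps": fps,
    "fps_id": fps - 1,  # the app conditions on fps - 1
    "motion_bucket_id": 127,
    "cond_aug": cond_aug,
    "cond_frames_without_noise": img,
    "cond_frames": img + cond_aug * torch.randn_like(img),
    "image_only_indicator": 0,
}
```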

scripts/sampling/configs/svd.yaml (new file)

@@ -0,0 +1,146 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 2.5
min_scale: 1.0
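
The `LinearPredictionGuider` configured above varies the guidance strength over the frame axis between `min_scale` and `max_scale` (1.0 to 2.5 here). Assuming the ramp is a plain linear interpolation across the 14 frames, as the name suggests, the per-frame scales would look like this:

```python
import torch

min_scale, max_scale, num_frames = 1.0, 2.5, 14
scales = torch.linspace(min_scale, max_scale, num_frames)
print(scales)  # earliest frames are guided least, the last frame the most
```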

scripts/sampling/configs/svd_image_decoder.yaml (new file)

@@ -0,0 +1,129 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 2.5
min_scale: 1.0

scripts/sampling/configs/svd_xt.yaml (new file)

@@ -0,0 +1,146 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_xt.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 3.0
min_scale: 1.5

scripts/sampling/configs/svd_xt_image_decoder.yaml (new file)

@@ -0,0 +1,129 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
params:
max_scale: 3.0
min_scale: 1.5

scripts/sampling/simple_video_sample.py (new file)

@@ -0,0 +1,278 @@
import math
import os
from glob import glob
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
num_frames: Optional[int] = None,
num_steps: Optional[int] = None,
version: str = "svd",
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 0.02,
seed: int = 23,
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
output_folder: Optional[str] = None,
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
if version == "svd":
num_frames = default(num_frames, 14)
num_steps = default(num_steps, 25)
output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
model_config = "scripts/sampling/configs/svd.yaml"
elif version == "svd_xt":
num_frames = default(num_frames, 25)
num_steps = default(num_steps, 30)
output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
model_config = "scripts/sampling/configs/svd_xt.yaml"
elif version == "svd_image_decoder":
num_frames = default(num_frames, 14)
num_steps = default(num_steps, 25)
output_folder = default(
output_folder, "outputs/simple_video_sample/svd_image_decoder/"
)
model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
elif version == "svd_xt_image_decoder":
num_frames = default(num_frames, 25)
num_steps = default(num_steps, 30)
output_folder = default(
output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
)
model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
else:
raise ValueError(f"Version {version} does not exist.")
model, filter = load_model(
model_config,
device,
num_frames,
num_steps,
)
torch.manual_seed(seed)
path = Path(input_path)
all_img_paths = []
if path.is_file():
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
all_img_paths = [input_path]
else:
raise ValueError("Path is not valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
else:
raise ValueError
for input_img_path in all_img_paths:
with Image.open(input_img_path) as image:
if image.mode == "RGBA":
image = image.convert("RGB")
w, h = image.size
if h % 64 != 0 or w % 64 != 0:
width, height = map(lambda x: x - x % 64, (w, h))
image = image.resize((width, height))
print(
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)
image = ToTensor()(image)
image = image * 2.0 - 1.0
image = image.unsqueeze(0).to(device)
H, W = image.shape[2:]
assert image.shape[1] == 3
F = 8
C = 4
shape = (num_frames, C, H // F, W // F)
if (H, W) != (576, 1024):
print(
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
)
if motion_bucket_id > 255:
print(
"WARNING: High motion bucket! This may lead to suboptimal performance."
)
if fps_id < 5:
print("WARNING: Small fps value! This may lead to suboptimal performance.")
if fps_id > 30:
print("WARNING: Large fps value! This may lead to suboptimal performance.")
value_dict = {}
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug
with torch.no_grad():
with torch.autocast(device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[1, num_frames],
T=num_frames,
device=device,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=[
"cond_frames",
"cond_frames_without_noise",
],
)
for k in ["crossattn", "concat"]:
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
randn = torch.randn(shape, device=device)
additional_model_inputs = {}
additional_model_inputs["image_only_indicator"] = torch.zeros(
2, num_frames
).to(device)
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
os.makedirs(output_folder, exist_ok=True)
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"),
fps_id + 1,
(samples.shape[-1], samples.shape[-2]),
)
samples = embed_watermark(samples)
samples = filter(samples)
vid = (
(rearrange(samples, "t c h w -> t h w c") * 255)
.cpu()
.numpy()
.astype(np.uint8)
)
for frame in vid:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N, T, device):
batch = {}
batch_uc = {}
for key in keys:
if key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to(device),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
)
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
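
As a usage sketch (not part of the script), the value_dict assembled in `sample` above can be fed to `get_batch` like this. `model`, `image`, and `num_frames` are placeholders for the objects created earlier in `sample`; scalar conditions end up repeated once per frame, while image conditions stay per-clip:

# Hedged sketch: minimal conditioning for one clip. `image` is a (1, 3, 576, 1024)
# tensor in [-1, 1]; `model` and `num_frames` come from load_model / sample above.
value_dict = {
    "fps_id": 6,
    "motion_bucket_id": 127,
    "cond_aug": 0.02,
    "cond_frames_without_noise": image,
    "cond_frames": image + 0.02 * torch.randn_like(image),
}
batch, batch_uc = get_batch(
    get_unique_embedder_keys_from_conditioner(model.conditioner),
    value_dict,
    [1, num_frames],
    T=num_frames,
    device="cuda",
)
# batch["fps_id"].shape == (num_frames,); batch["cond_frames"].shape == (1, 3, 576, 1024)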
def load_model(
config: str,
device: str,
num_frames: int,
num_steps: int,
):
config = OmegaConf.load(config)
if device == "cuda":
config.model.params.conditioner_config.params.emb_models[
0
].params.open_clip_embedding_config.params.init_device = device
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
)
if device == "cuda":
with torch.device(device):
model = instantiate_from_config(config.model).to(device).eval()
else:
model = instantiate_from_config(config.model).to(device).eval()
filter = DeepFloydDataFiltering(verbose=False, device=device)
return model, filter
if __name__ == "__main__":
Fire(sample)
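
Since the entry point is wrapped with `Fire(sample)`, every keyword argument of `sample` doubles as a command-line flag, e.g. `python scripts/sampling/simple_video_sample.py --version svd_xt --input_path assets/test_image.png`. The same call can be made from Python; a hedged sketch, assuming the repository root is on PYTHONPATH and the checkpoints referenced by `scripts/sampling/configs/svd_xt.yaml` are available locally:

from scripts.sampling.simple_video_sample import sample

# Generate a 25-frame clip from one conditioning image.
# Lower decoding_t if the VAE decode runs out of VRAM.
sample(
    input_path="assets/test_image.png",
    version="svd_xt",
    decoding_t=5,
    seed=23,
)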

View File

@@ -1,10 +1,10 @@
-import torch
 import einops
-from torch.backends.cuda import SDPBackend
+import torch
 import torch.nn.functional as F
 import torch.utils.benchmark as benchmark
+from torch.backends.cuda import SDPBackend
-from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
+from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
 def benchmark_attn():

View File

@@ -37,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
 class DeepFloydDataFiltering(object):
-    def __init__(self, verbose: bool = False):
+    def __init__(
+        self, verbose: bool = False, device: torch.device = torch.device("cpu")
+    ):
         super().__init__()
         self.verbose = verbose
-        self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
+        self._device = None
+        self.clip_model, _ = clip.load("ViT-L/14", device=device)
         self.clip_model.eval()
         self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
@@ -54,7 +57,9 @@ class DeepFloydDataFiltering(object):
     @torch.inference_mode()
     def __call__(self, images: torch.Tensor) -> torch.Tensor:
         imgs = clip_process_images(images)
-        image_features = self.clip_model.encode_image(imgs.to("cpu"))
+        if self._device is None:
+            self._device = next(p for p in self.clip_model.parameters()).device
+        image_features = self.clip_model.encode_image(imgs.to(self._device))
         image_features = image_features.detach().cpu().numpy().astype(np.float16)
         p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
         w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
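
With the `device` argument added above, the CLIP backbone of the filter can sit on the GPU while the logistic-regression heads stay in NumPy; `_device` is resolved lazily from the model's parameters on the first call. A standalone sketch of running the filter over decoded frames (the random tensor is a stand-in for real samples, and the import assumes the repository root is on PYTHONPATH):

import torch
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering

# Frames are expected in [0, 1], shaped (num_frames, 3, H, W).
filter = DeepFloydDataFiltering(verbose=False, device=torch.device("cuda"))
frames = torch.rand(14, 3, 576, 1024)  # stand-in for decoded video frames
frames = filter(frames)                # tensor of the same shape; flagged frames handled by the filter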

View File

@@ -1,22 +1,19 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import pathlib import pathlib
from sgm.inference.helpers import ( from dataclasses import asdict, dataclass
do_sample, from enum import Enum
do_img2img, from typing import Optional
Img2ImgDiscretizationWrapper,
) from omegaconf import OmegaConf
from sgm.modules.diffusionmodules.sampling import (
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
do_sample)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler, EulerEDMSampler,
HeunEDMSampler, HeunEDMSampler,
EulerAncestralSampler, LinearMultistepSampler)
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import load_model_from_config from sgm.util import load_model_from_config
from typing import Optional
class ModelArchitecture(str, Enum): class ModelArchitecture(str, Enum):

View File

@@ -1,13 +1,13 @@
-import os
-from typing import Union, List, Optional
 import math
+import os
+from typing import List, Optional, Union
 import numpy as np
 import torch
-from PIL import Image
 from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
+from PIL import Image
 from torch import autocast
 from sgm.util import append_dims
@@ -20,17 +20,16 @@ class WatermarkEmbedder:
         self.encoder = WatermarkEncoder()
         self.encoder.set_watermark("bits", self.watermark)
-    def __call__(self, image: torch.Tensor):
+    def __call__(self, image: torch.Tensor) -> torch.Tensor:
         """
         Adds a predefined watermark to the input image
         Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
+            image: ([N,] B, RGB, H, W) in range [0, 1]
         Returns:
             same as input but watermarked
         """
-        # watermarking library expects input as cv2 BGR format
         squeeze = len(image.shape) == 4
         if squeeze:
             image = image[None, ...]
@@ -39,6 +38,7 @@ class WatermarkEmbedder:
             (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
         ).numpy()[:, :, :, ::-1]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        # watermarking library expects input as cv2 BGR format
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
         image = torch.from_numpy(
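
For reference, the watermark written by `encode(..., "dwtDct")` can be read back with the matching decoder from the same `imwatermark` package. A hedged sketch; the frame path is hypothetical and the bit count must match the bit sequence the embedder was configured with:

import cv2
from imwatermark import WatermarkDecoder

# Hypothetical frame dumped from a generated video (cv2 loads it as BGR, which the decoder expects).
bgr = cv2.imread("outputs/simple_video_sample/svd/frame_000000.png")
# Placeholder bit length: it must equal len(WATERMARK_BITS) used by the embedder.
decoder = WatermarkDecoder("bits", 48)
bits = decoder.decode(bgr, "dwtDct")
print(bits)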

View File

@@ -1,18 +1,22 @@
import logging
import math
import re import re
from abc import abstractmethod from abc import abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from omegaconf import ListConfig import torch.nn as nn
from einops import rearrange
from packaging import version from packaging import version
from safetensors.torch import load_file as load_safetensors
from ..modules.diffusionmodules.model import Decoder, Encoder from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config from ..util import (default, get_nested_attribute, get_obj_from_str,
instantiate_from_config)
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule): class AbstractAutoencoder(pl.LightningModule):
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
ema_decay: Union[None, float] = None, ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None, monitor: Union[None, str] = None,
input_key: str = "jpg", input_key: str = "jpg",
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list, ListConfig] = (),
): ):
super().__init__() super().__init__()
self.input_key = input_key self.input_key = input_key
self.use_ema = ema_decay is not None self.use_ema = ema_decay is not None
if monitor is not None: if monitor is not None:
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
if self.use_ema: if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay) self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if version.parse(torch.__version__) >= version.parse("2.0.0"): if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False self.automatic_optimization = False
def init_from_ckpt( def apply_ckpt(self, ckpt: Union[None, str, dict]):
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple() if ckpt is None:
) -> None: return
if path.endswith("ckpt"): if isinstance(ckpt, str):
sd = torch.load(path, map_location="cpu")["state_dict"] ckpt = {
elif path.endswith("safetensors"): "target": "sgm.modules.checkpoint.CheckpointEngine",
sd = load_safetensors(path) "params": {"ckpt_path": ckpt},
else: }
raise NotImplementedError engine = instantiate_from_config(ckpt)
engine(self)
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@abstractmethod @abstractmethod
def get_input(self, batch) -> Any: def get_input(self, batch) -> Any:
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f"{context}: Switched to EMA weights") logpy.info(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f"{context}: Restored training weights") logpy.info(f"{context}: Restored training weights")
@abstractmethod @abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor: def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
raise NotImplementedError("decode()-method of abstract base class called") raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg): def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config") logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])( return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict()) params, lr=lr, **cfg.get("params", dict())
) )
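
The `apply_ckpt` method introduced above replaces the old `init_from_ckpt`: a bare path string is wrapped into a `CheckpointEngine` config and instantiated, so both call styles below end up on the same code path. A sketch, with `model` and the checkpoint path as placeholders:

# Hedged sketch: the two forms accepted by apply_ckpt after this refactor.
model.apply_ckpt("checkpoints/my_autoencoder.safetensors")  # placeholder path

model.apply_ckpt(
    {
        "target": "sgm.modules.checkpoint.CheckpointEngine",
        "params": {"ckpt_path": "checkpoints/my_autoencoder.safetensors"},
    }
)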
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
regularizer_config: Dict, regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None, optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0, lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# todo: add options to freeze encoder/decoder self.automatic_optimization = False # pytorch lightning
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config) self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.loss = instantiate_from_config(loss_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization = instantiate_from_config(regularizer_config) self.loss: torch.nn.Module = instantiate_from_config(loss_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config
)
self.optimizer_config = default( self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"} optimizer_config, {"target": "torch.optim.Adam"}
) )
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consistent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor: def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict. # assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc) # image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead of bhwc)
return batch[self.input_key] return batch[self.input_key]
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:
params = ( params = []
list(self.encoder.parameters()) if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+ list(self.decoder.parameters()) params += list(self.loss.get_trainable_autoencoder_parameters())
+ list(self.regularization.get_trainable_parameters()) if hasattr(self.regularization, "get_trainable_parameters"):
+ list(self.loss.get_trainable_autoencoder_parameters()) params += list(self.regularization.get_trainable_parameters())
) params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params return params
def get_discriminator_params(self) -> list: def get_discriminator_params(self) -> list:
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params return params
def get_last_layer(self): def get_last_layer(self):
return self.decoder.get_last_layer() return self.decoder.get_last_layer()
def encode(self, x: Any, return_reg_log: bool = False) -> Any: def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x) z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z) z, reg_log = self.regularization(z)
if return_reg_log: if return_reg_log:
return z, reg_log return z, reg_log
return z return z
def decode(self, z: Any) -> torch.Tensor: def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z) x = self.decoder(z, **kwargs)
return x return x
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True) z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z) dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log return z, dec, reg_log
def training_step(self, batch, batch_idx, optimizer_idx) -> Any: def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch) x = self.get_input(batch)
z, xrec, regularization_log = self(x) additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": optimizer_idx,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "train",
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0: if optimizer_idx == 0:
# autoencode # autoencode
aeloss, log_dict_ae = self.loss( out_loss = self.loss(x, xrec, **extra_info)
regularization_log, if isinstance(out_loss, tuple):
x, aeloss, log_dict_ae = out_loss
xrec, else:
optimizer_idx, # simple loss function
self.global_step, aeloss = out_loss
last_layer=self.get_last_layer(), log_dict_ae = {"train/loss/rec": aeloss.detach()}
split="train",
)
self.log_dict( self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
"loss",
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
) )
return aeloss return aeloss
elif optimizer_idx == 1:
if optimizer_idx == 1:
# discriminator # discriminator
discloss, log_dict_disc = self.loss( discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
regularization_log, # -> discriminator always needs to return a tuple
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict( self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
) )
return discloss return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def validation_step(self, batch, batch_idx) -> Dict: def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(
batch, batch_idx, optimizer_idx=optimizer_idx
)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx) log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope(): with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema) log_dict.update(log_dict_ema)
return log_dict return log_dict
def _validation_step(self, batch, batch_idx, postfix="") -> Dict: def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch) x = self.get_input(batch)
z, xrec, regularization_log = self(x) z, xrec, regularization_log = self(x)
aeloss, log_dict_ae = self.loss( if hasattr(self.loss, "forward_keys"):
regularization_log, extra_info = {
x, "z": z,
xrec, "optimizer_idx": 0,
0, "global_step": self.global_step,
self.global_step, "last_layer": self.get_last_layer(),
last_layer=self.get_last_layer(), "split": "val" + postfix,
split="val" + postfix, "regularization_log": regularization_log,
) "autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
full_log_dict = log_dict_ae
discloss, log_dict_disc = self.loss( if "optimizer_idx" in extra_info:
regularization_log, extra_info["optimizer_idx"] = 1
x, discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
xrec, full_log_dict.update(log_dict_disc)
1, self.log(
self.global_step, f"val{postfix}/loss/rec",
last_layer=self.get_last_layer(), log_dict_ae[f"val{postfix}/loss/rec"],
split="val" + postfix, sync_dist=True,
) )
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(full_log_dict, sync_dist=True)
log_dict_ae.update(log_dict_disc) return full_log_dict
self.log_dict(log_dict_ae)
return log_dict_ae
def configure_optimizers(self) -> Any: def get_param_groups(
self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(f"Did not find parameters for pattern {pattern_}")
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
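
`get_param_groups` lets `trainable_ae_params` and `trainable_disc_params` restrict training to parameters whose names match the given regular expressions, one optimizer group per pattern list, with per-group optimizer kwargs taken from the matching `*_optimizer_args` entry. A hedged config sketch of what such constructor arguments could look like (the pattern and kwargs are illustrative, not defaults):

# Train only the decoder of the autoencoder, with its own optimizer kwargs.
trainable_ae_params = [[r"decoder\..*"]]       # regexes matched against named_parameters()
ae_optimizer_args = [{"weight_decay": 0.0}]    # one dict of optimizer kwargs per group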
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params() ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params() disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(
f"Number of trainable discriminator parameters: {num_disc_params:,}"
)
opt_ae = self.instantiate_optimizer_from_config( opt_ae = self.instantiate_optimizer_from_config(
ae_params, ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate, default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config, self.optimizer_config,
) )
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config( opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config disc_params, self.learning_rate, self.optimizer_config
) )
opts.append(opt_disc)
return [opt_ae, opt_disc], [] return opts
@torch.no_grad() @torch.no_grad()
def log_images(self, batch: Dict, **kwargs) -> Dict: def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict() log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch) x = self.get_input(batch)
_, xrec, _ = self(x) additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x log["inputs"] = x
log["reconstructions"] = xrec log["reconstructions"] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
)
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope(): with self.ema_scope():
_, xrec_ema, _ = self(x) _, xrec_ema, _ = self(x, **additional_decode_kwargs)
log["reconstructions_ema"] = xrec_ema log["reconstructions_ema"] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = "reconstructions-" + "-".join(
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
)
log[log_str] = xrec_add
return log return log
class AutoencoderKL(AutoencodingEngine): class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs): def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig") ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None) ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", ()) ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__( super().__init__(
encoder_config={"target": "torch.nn.Identity"}, encoder_config={
decoder_config={"target": "torch.nn.Identity"}, "target": "sgm.modules.diffusionmodules.model.Encoder",
regularizer_config={"target": "torch.nn.Identity"}, "params": ddconfig,
loss_config=kwargs.pop("lossconfig"), },
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs, **kwargs,
) )
assert ddconfig["double_z"] self.quant_conv = torch.nn.Conv2d(
self.encoder = Encoder(**ddconfig) (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
self.decoder = Decoder(**ddconfig) (1 + ddconfig["double_z"]) * embed_dim,
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if ckpt_path is not None: self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def encode(self, x): def get_autoencoder_params(self) -> list:
assert ( params = super().get_autoencoder_params()
not self.training return params
), f"{self.__class__.__name__} only supports inference currently"
h = self.encoder(x) def encode(
moments = self.quant_conv(h) self, x: torch.Tensor, return_reg_log: bool = False
posterior = DiagonalGaussianDistribution(moments) ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
return posterior if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
def decode(self, z, **decoder_kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z, **decoder_kwargs)
return dec return dec
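
The `max_batch_size` branches above are a plain chunking loop: push at most `bs` items through the encoder or decoder per forward pass and concatenate the results. The same pattern, factored into a small helper purely for illustration (not part of the module):

import math
import torch

def apply_in_chunks(fn, x: torch.Tensor, bs: int) -> torch.Tensor:
    # run fn on slices of at most bs items and stitch the outputs back together
    n_batches = int(math.ceil(x.shape[0] / bs))
    return torch.cat([fn(x[i * bs : (i + 1) * bs]) for i in range(n_batches)], dim=0)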
class AutoencoderKLInferenceWrapper(AutoencoderKL): class AutoencoderKL(AutoencodingEngineLegacy):
def encode(self, x): def __init__(self, **kwargs):
return super().encode(x).sample() if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
)
},
**kwargs,
)
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
def __init__(
self,
embed_dim: int,
n_embed: int,
sane_index_shape: bool = False,
**kwargs,
):
if "lossconfig" in kwargs:
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
),
"params": {
"n_e": n_embed,
"e_dim": embed_dim,
"sane_index_shape": sane_index_shape,
},
},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder): class IdentityFirstStage(AbstractAutoencoder):
@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
def decode(self, x: Any, *args, **kwargs) -> Any: def decode(self, x: Any, *args, **kwargs) -> Any:
return x return x
class AEIntegerWrapper(nn.Module):
def __init__(
self,
model: nn.Module,
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
regularization_key: str = "regularization",
encoder_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.model = model
assert hasattr(model, "encode") and hasattr(
model, "decode"
), "Need AE interface"
self.regularization = get_nested_attribute(model, regularization_key)
self.shape = shape
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
def encode(self, x) -> torch.Tensor:
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
_, log = self.model.encode(x, **self.encoder_kwargs)
assert isinstance(log, dict)
inds = log["min_encoding_indices"]
return rearrange(inds, "b ... -> b (...)")
def decode(
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
) -> torch.Tensor:
# expect inds shape (b, s) with s = h*w
shape = default(shape, self.shape) # Optional[(h, w)]
if shape is not None:
assert len(shape) == 2, f"Unhandled shape {shape}"
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
h = rearrange(h, "b h w c -> b c h w")
return self.model.decode(h)
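
`AEIntegerWrapper` exposes a VQ autoencoder as a mapping between images and flat integer codebook indices, which is convenient when a downstream model operates on discrete tokens. A usage sketch, where `vq_model` and `images` are placeholders and the wrapped model is assumed to be an `AutoencoderLegacyVQ` (or any autoencoder whose regularizer returns `min_encoding_indices` and provides `get_codebook_entry`):

# Hedged sketch: images -> (b, 16*16) codebook indices -> reconstructed images.
wrapper = AEIntegerWrapper(vq_model, shape=(16, 16))
wrapper.eval()                      # encode() asserts the wrapper is not in training mode
inds = wrapper.encode(images)       # flat integer indices per image
recons = wrapper.decode(inds)       # images decoded from the codebook entries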
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
),
"params": {"sample": False},
},
**kwargs,
)

View File

@@ -1,5 +1,6 @@
import math
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from ..modules import UNCONDITIONAL_CONFIG from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma from ..modules.ema import LitEma
from ..util import ( from ..util import (default, disabled_train, get_obj_from_str,
default, instantiate_from_config, log_txt_as_img)
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
class DiffusionEngine(pl.LightningModule): class DiffusionEngine(pl.LightningModule):
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
log_keys: Union[List, None] = None, log_keys: Union[List, None] = None,
no_cond_log: bool = False, no_cond_log: bool = False,
compile_model: bool = False, compile_model: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
): ):
super().__init__() super().__init__()
self.log_keys = log_keys self.log_keys = log_keys
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path) self.init_from_ckpt(ckpt_path)
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def init_from_ckpt( def init_from_ckpt(
self, self,
path: str, path: str,
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def decode_first_stage(self, z): def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
out = self.first_stage_model.decode(z) for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out return out
@torch.no_grad() @torch.no_grad()
def encode_first_stage(self, x): def encode_first_stage(self, x):
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
z = self.first_stage_model.encode(x) for n in range(n_rounds):
out = self.first_stage_model.encode(
x[n * n_samples : (n + 1) * n_samples]
)
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z z = self.scale_factor * z
return z return z
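
Both loops above bound peak memory by pushing at most `en_and_decode_n_samples_a_time` latents through the first-stage model per call; this is the hook that `decoding_t` in `scripts/sampling/simple_video_sample.py` sets before decoding a video. A hedged usage sketch, where `model` is a DiffusionEngine and `samples_z` the sampled latents:

# Decode two latent frames per VAE forward pass instead of all of them at once,
# trading extra decode calls for a lower peak VRAM footprint.
model.en_and_decode_n_samples_a_time = 2
frames = model.decode_first_stage(samples_z)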

View File

@@ -1,3 +1,4 @@
import logging
import math import math
from inspect import isfunction from inspect import isfunction
from typing import Any, Optional from typing import Any, Optional
@@ -7,6 +8,9 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.utils.checkpoint import checkpoint
logpy = logging.getLogger(__name__)
if version.parse(torch.__version__) >= version.parse("2.0.0"): if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True SDP_IS_AVAILABLE = True
@@ -36,9 +40,10 @@ else:
SDP_IS_AVAILABLE = False SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext sdp_kernel = nullcontext
BACKEND_MAP = {} BACKEND_MAP = {}
print( logpy.warn(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " f"No SDP backend available, likely because you are running in pytorch "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
f"You might want to consider upgrading."
) )
try: try:
@@ -48,9 +53,9 @@ try:
XFORMERS_IS_AVAILABLE = True XFORMERS_IS_AVAILABLE = True
except: except:
XFORMERS_IS_AVAILABLE = False XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...") logpy.warn("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint # from .diffusionmodules.util import mixed_checkpoint as checkpoint
def exists(val): def exists(val):
@@ -146,6 +151,62 @@ class LinearAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "math")
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
attn_mode: str = "xformers",
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
assert attn_mode in self.ATTENTION_MODES
self.attn_mode = attn_mode
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, L, C = x.shape
qkv = self.qkv(x)
if self.attn_mode == "torch":
qkv = rearrange(
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
).float()
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
elif self.attn_mode == "xformers":
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
x = xformers.ops.memory_efficient_attention(q, k, v)
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
elif self.attn_mode == "math":
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
else:
raise NotImplementedError(f"Unknown attention mode {self.attn_mode}")
x = self.proj(x)
x = self.proj_drop(x)
return x
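
The new `SelfAttention` block is a standard multi-head self-attention over token sequences, with the backend selected by `attn_mode`. A quick shape check, using the PyTorch 2.x scaled-dot-product backend so no xformers install is needed (sizes are arbitrary):

import torch

attn = SelfAttention(dim=512, num_heads=8, attn_mode="torch")
y = attn(torch.randn(2, 77, 512))
print(y.shape)  # torch.Size([2, 77, 512])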
class SpatialSelfAttention(nn.Module): class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
@@ -289,9 +350,10 @@ class MemoryEfficientCrossAttention(nn.Module):
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
): ):
super().__init__() super().__init__()
print( logpy.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
f"{heads} heads with a dimension of {dim_head}." f"context_dim is {context_dim} and using {heads} heads with a "
f"dimension of {dim_head}."
) )
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
@@ -352,6 +414,26 @@ class MemoryEfficientCrossAttention(nn.Module):
) )
# actually compute the attention, what we cannot get enough of # actually compute the attention, what we cannot get enough of
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
# NOTE: workaround for
# https://github.com/facebookresearch/xformers/issues/845
max_bs = 32768
N = q.shape[0]
n_batches = math.ceil(N / max_bs)
out = list()
for i_batch in range(n_batches):
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
out.append(
xformers.ops.memory_efficient_attention(
q[batch],
k[batch],
v[batch],
attn_bias=None,
op=self.attention_op,
)
)
out = torch.cat(out, 0)
else:
out = xformers.ops.memory_efficient_attention( out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op q, k, v, attn_bias=None, op=self.attention_op
) )
@@ -393,21 +475,24 @@ class BasicTransformerBlock(nn.Module):
super().__init__() super().__init__()
assert attn_mode in self.ATTENTION_MODES assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print( logpy.warn(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. " f"Attention mode '{attn_mode}' is not available. Falling "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" f"back to native attention. This is not a problem in "
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
f"version {torch.__version__}."
) )
attn_mode = "softmax" attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print( logpy.warn(
"We do not support vanilla attention anymore, as it is too expensive. Sorry." "We do not support vanilla attention anymore, as it is too "
"expensive. Sorry."
) )
if not XFORMERS_IS_AVAILABLE: if not XFORMERS_IS_AVAILABLE:
assert ( assert (
False False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'" ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else: else:
print("Falling back to xformers efficient attention.") logpy.info("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers" attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode] attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"): if version.parse(torch.__version__) >= version.parse("2.0.0"):
@@ -437,7 +522,7 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
if self.checkpoint: if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing") logpy.debug(f"{self.__class__.__name__} is using checkpointing")
def forward( def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -456,9 +541,12 @@ class BasicTransformerBlock(nn.Module):
) )
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint( if self.checkpoint:
self._forward, (x, context), self.parameters(), self.checkpoint # inputs = {"x": x, "context": context}
) return checkpoint(self._forward, x, context)
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
else:
return self._forward(**kwargs)
def _forward( def _forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -518,9 +606,9 @@ class BasicTransformerSingleLayerBlock(nn.Module):
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None):
return checkpoint( # inputs = {"x": x, "context": context}
self._forward, (x, context), self.parameters(), self.checkpoint # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
) return checkpoint(self._forward, x, context)
def _forward(self, x, context=None): def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x x = self.attn1(self.norm1(x), context=context) + x
@@ -554,18 +642,20 @@ class SpatialTransformer(nn.Module):
sdp_backend=None, sdp_backend=None,
): ):
super().__init__() super().__init__()
print( logpy.debug(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" f"constructing {self.__class__.__name__} of depth {depth} w/ "
f"{in_channels} channels and {n_heads} heads."
) )
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list): if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim): if depth != len(context_dim):
print( logpy.warn(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " f"{self.__class__.__name__}: Found context dims "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." f"{context_dim} of depth {len(context_dim)}, which does not "
f"match the specified 'depth' of {depth}. Setting context_dim "
f"to {depth * [context_dim[0]]} now."
) )
# depth does not match context dims. # depth does not match context dims.
assert all( assert all(
@@ -631,3 +721,39 @@ class SpatialTransformer(nn.Module):
if not self.use_linear: if not self.use_linear:
x = self.proj_out(x) x = self.proj_out(x)
return x + x_in return x + x_in
class SimpleTransformer(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dim_head: int,
context_dim: Optional[int] = None,
dropout: float = 0.0,
checkpoint: bool = True,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
BasicTransformerBlock(
dim,
heads,
dim_head,
dropout=dropout,
context_dim=context_dim,
attn_mode="softmax-xformers",
checkpoint=checkpoint,
)
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, context)
return x
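
`SimpleTransformer` stacks plain `BasicTransformerBlock`s (self-attention plus optional cross-attention) without the spatial reshaping done by `SpatialTransformer`. A shape sketch with arbitrary sizes; note the blocks are hard-wired to `attn_mode="softmax-xformers"`, so this assumes xformers is installed:

import torch

tf = SimpleTransformer(dim=320, depth=2, heads=8, dim_head=40, context_dim=1024)
x = torch.randn(4, 196, 320)     # (batch, tokens, dim)
ctx = torch.randn(4, 77, 1024)   # (batch, context tokens, context_dim)
y = tf(x, ctx)
print(y.shape)                   # torch.Size([4, 196, 320])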

View File

@@ -1,246 +1,7 @@
from typing import Any, Union __all__ = [
"GeneralLPIPSWithDiscriminator",
"LatentLPIPS",
]
import torch from .discriminator_loss import GeneralLPIPSWithDiscriminator
import torch.nn as nn from .lpips import LatentLPIPS
from einops import rearrange
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
pixelloss_weight=1.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss calculation, "
f"the LPIPS loss will be applied to each frame independently. "
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.learn_logvar = learn_logvar
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
def get_trainable_parameters(self) -> Any:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Any:
if self.learn_logvar:
yield self.logvar
yield from ()
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
regularization_log,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
split="train",
weights=None,
):
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
log.update(
{
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
)
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log

View File

@@ -0,0 +1,306 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, Dict[str, float]] = None,
additional_log_keys: Optional[List[str]] = None,
discriminator_config: Optional[Dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss "
f"calculation, the LPIPS loss will be applied to each frame "
f"independently."
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(
torch.full((), logvar_init), requires_grad=learn_logvar
)
self.learn_logvar = learn_logvar
discriminator_config = default(
discriminator_config,
{
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
"params": {
"input_nc": disc_in_channels,
"n_layers": disc_num_layers,
"use_actnorm": False,
},
},
)
self.discriminator = instantiate_from_config(discriminator_config).apply(
weights_init
)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
self.forward_keys = [
"optimizer_idx",
"global_step",
"last_layer",
"split",
"regularization_log",
]
self.additional_log_keys = set(default(additional_log_keys, []))
self.additional_log_keys.update(set(self.regularization_weights.keys()))
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
if self.learn_logvar:
yield self.logvar
yield from ()
@torch.no_grad()
def log_images(
self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
# Non patch-discriminator
return dict()
logits_fake = self.discriminator(reconstructions.contiguous().detach())
# -> (b, 1, h, w)
# parameters for colormapping
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
cmap = colormaps["PiYG"] # diverging colormap
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
"""(b, 1, ...) -> (b, 3, ...)"""
logits = (logits + high) / (2 * high)
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
# -> (b, 1, ..., 3)
logits = torch.from_numpy(logits_np).to(logits.device)
return rearrange(logits, "b 1 ... c -> b c ...")
logits_real = torch.nn.functional.interpolate(
logits_real,
size=inputs.shape[-2:],
mode="nearest",
antialias=False,
)
logits_fake = torch.nn.functional.interpolate(
logits_fake,
size=reconstructions.shape[-2:],
mode="nearest",
antialias=False,
)
# alpha value of logits for overlay
alpha_real = torch.abs(logits_real) / high
alpha_fake = torch.abs(logits_fake) / high
# -> (b, 1, h, w) in range [0, 0.5]
# alpha value of lines don't really matter, since the values are the same
# for both images and logits anyway
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
# -> (1, h, w)
# blend logits and images together
# prepare logits for plotting
logits_real = to_colormap(logits_real)
logits_fake = to_colormap(logits_fake)
# resize logits
# -> (b, 3, h, w)
# make some grids
# add all logits to one plot
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
# I just love how torchvision calls the number of columns `nrow`
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
# -> (3, h, w)
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
grid_images_fake = torchvision.utils.make_grid(
0.5 * reconstructions + 0.5, nrow=4
)
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
# -> (3, h, w) in range [0, 1]
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
# Create labeled colorbar
dpi = 100
height = 128 / dpi
width = grid_logits.shape[2] / dpi
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
plt.colorbar(
img,
cax=ax,
orientation="horizontal",
fraction=0.9,
aspect=width / height,
pad=0.0,
)
img.set_visible(False)
fig.tight_layout()
fig.canvas.draw()
# manually convert figure to numpy
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
# Add colorbar to plot
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
blended_grid = torch.cat((grid_blend, cbar), dim=1)
return {
"vis_logits": 2 * annotated_grid[None, ...] - 1,
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
}
def calculate_adaptive_weight(
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
) -> torch.Tensor:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
inputs: torch.Tensor,
reconstructions: torch.Tensor,
*, # added because I changed the order here
regularization_log: Dict[str, torch.Tensor],
optimizer_idx: int,
global_step: int,
last_layer: torch.Tensor,
split: str = "train",
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
# now the GAN part
if optimizer_idx == 0:
# generator update
if global_step >= self.discriminator_iter_start or not self.training:
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else:
d_weight = torch.tensor(1.0)
else:
d_weight = torch.tensor(0.0)
g_loss = torch.tensor(0.0, requires_grad=True)
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
if k in self.additional_log_keys:
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
log.update(
{
f"{split}/loss/total": loss.clone().detach().mean(),
f"{split}/loss/nll": nll_loss.detach().mean(),
f"{split}/loss/rec": rec_loss.detach().mean(),
f"{split}/loss/g": g_loss.detach().mean(),
f"{split}/scalars/logvar": self.logvar.detach(),
f"{split}/scalars/d_weight": d_weight.detach(),
}
)
return loss, log
elif optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
if global_step >= self.discriminator_iter_start or not self.training:
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
else:
d_loss = torch.tensor(0.0, requires_grad=True)
log = {
f"{split}/loss/disc": d_loss.clone().detach().mean(),
f"{split}/logits/real": logits_real.detach().mean(),
f"{split}/logits/fake": logits_fake.detach().mean(),
}
return d_loss, log
else:
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
def get_nll_loss(
self,
rec_loss: torch.Tensor,
weights: Optional[Union[float, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
return nll_loss, weighted_nll_loss
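
The `get_nll_loss` helper implements a Gaussian negative log-likelihood with a single (optionally learned) log-variance. A toy evaluation of that reduction, with illustrative values only:

```python
import torch

# rec_loss stands in for the per-pixel |x - x_hat| (+ LPIPS) map computed above.
rec_loss = torch.rand(2, 3, 8, 8)
logvar = torch.tensor(0.0)  # an nn.Parameter in the class above

# Dividing by exp(logvar) and adding logvar trades reconstruction sharpness
# against the (possibly learned) observation variance.
nll_loss = rec_loss / torch.exp(logvar) + logvar
nll_per_sample = torch.sum(nll_loss) / nll_loss.shape[0]
print(nll_per_sample)
```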

View File

@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log

View File

@@ -5,19 +5,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

-from ....modules.distributions.distributions import DiagonalGaussianDistribution
-
-
-class AbstractRegularizer(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
-        raise NotImplementedError()
-
-    @abstractmethod
-    def get_trainable_parameters(self) -> Any:
-        raise NotImplementedError()
+from ....modules.distributions.distributions import \
+    DiagonalGaussianDistribution
+
+from .base import AbstractRegularizer


class DiagonalGaussianRegularizer(AbstractRegularizer):
@@ -39,15 +29,3 @@ class DiagonalGaussianRegularizer(AbstractRegularizer):
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log
-
-
-def measure_perplexity(predicted_indices, num_centroids):
-    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
-    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
-    encodings = (
-        F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
-    )
-    avg_probs = encodings.mean(0)
-    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
-    cluster_use = torch.sum(avg_probs > 0)
-    return perplexity, cluster_use

View File

@@ -0,0 +1,40 @@
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn.functional as F
from torch import nn
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(
predicted_indices: torch.Tensor, num_centroids: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
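
A quick sanity check of `measure_perplexity` (assuming the definitions above are in scope): with indices drawn uniformly over the codebook, the perplexity approaches `num_centroids` and every cluster is used.

```python
import torch

predicted_indices = torch.randint(0, 16, (4096,))  # toy, roughly uniform codebook usage
perplexity, cluster_use = measure_perplexity(predicted_indices, num_centroids=16)
print(perplexity, cluster_use)  # perplexity close to 16, cluster_use == 16
```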

View File

@@ -0,0 +1,487 @@
import logging
from abc import abstractmethod
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from .base import AbstractRegularizer, measure_perplexity
logpy = logging.getLogger(__name__)
class AbstractQuantizer(AbstractRegularizer):
def __init__(self):
super().__init__()
# Define these in your init
# shape (N,)
self.used: Optional[torch.Tensor]
self.re_embed: int
self.unknown_index: Union[Literal["random"], int]
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
assert self.used is not None, "You need to define used indices for remap"
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
@abstractmethod
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
raise NotImplementedError()
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
yield from self.parameters()
class GumbelQuantizer(AbstractQuantizer):
"""
credit to @karpathy:
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens: int,
embedding_dim: int,
n_embed: int,
straight_through: bool = True,
kl_weight: float = 5e-4,
temp_init: float = 1.0,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
) -> None:
super().__init__()
self.loss_key = loss_key
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
) -> Tuple[torch.Tensor, Dict]:
# force hard = True when we are in eval mode, as we must quantize.
# actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
out_dict = {}
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = (
self.kl_weight
* torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
)
out_dict[self.loss_key] = diff
ind = soft_one_hot.argmax(dim=1)
out_dict["indices"] = ind
if self.remap is not None:
ind = self.remap_to_used(ind)
if return_logits:
out_dict["logits"] = logits
return z_q, out_dict
def get_codebook_entry(self, indices, shape):
# TODO: shape not yet optional
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = (
F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
)
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
return z_q
class VectorQuantizer(AbstractQuantizer):
"""
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term,
beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(
self,
n_e: int,
e_dim: int,
beta: float = 0.25,
remap: Optional[str] = None,
unknown_index: str = "random",
sane_index_shape: bool = False,
log_perplexity: bool = False,
embedding_weight_norm: bool = False,
loss_key: str = "loss/vq",
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.loss_key = loss_key
if not embedding_weight_norm:
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
else:
self.embedding = torch.nn.utils.weight_norm(
nn.Embedding(self.n_e, self.e_dim), dim=1
)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_e
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
self.sane_index_shape = sane_index_shape
self.log_perplexity = log_perplexity
def forward(
self,
z: torch.Tensor,
) -> Tuple[torch.Tensor, Dict]:
do_reshape = z.ndim == 4
if do_reshape:
# # reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, "b c h w -> b h w c").contiguous()
else:
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
z = z.contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2
* torch.einsum(
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
)
)
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
loss_dict = {}
if self.log_perplexity:
perplexity, cluster_usage = measure_perplexity(
min_encoding_indices.detach(), self.n_e
)
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
# compute loss for embedding
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
loss_dict[self.loss_key] = loss
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
if do_reshape:
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
if do_reshape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
else:
min_encoding_indices = rearrange(
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
)
loss_dict["min_encoding_indices"] = min_encoding_indices
return z_q, loss_dict
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
assert shape is not None, "Need to give shape for remap"
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class EmbeddingEMA(nn.Module):
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
super().__init__()
self.decay = decay
self.eps = eps
weight = torch.randn(num_tokens, codebook_dim)
self.weight = nn.Parameter(weight, requires_grad=False)
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
self.update = True
def forward(self, embed_id):
return F.embedding(embed_id, self.weight)
def cluster_size_ema_update(self, new_cluster_size):
self.cluster_size.data.mul_(self.decay).add_(
new_cluster_size, alpha=1 - self.decay
)
def embed_avg_ema_update(self, new_embed_avg):
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
def weight_update(self, num_tokens):
n = self.cluster_size.sum()
smoothed_cluster_size = (
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
)
# normalize embedding average with smoothed cluster size
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
self.weight.data.copy_(embed_normalized)
class EMAVectorQuantizer(AbstractQuantizer):
def __init__(
self,
n_embed: int,
embedding_dim: int,
beta: float,
decay: float = 0.99,
eps: float = 1e-5,
remap: Optional[str] = None,
unknown_index: str = "random",
loss_key: str = "loss/vq",
):
super().__init__()
self.codebook_dim = embedding_dim
self.num_tokens = n_embed
self.beta = beta
self.loss_key = loss_key
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
else:
self.used = None
self.re_embed = n_embed
if unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
else:
assert unknown_index == "random" or isinstance(
unknown_index, int
), "unknown index needs to be 'random', 'extra' or any integer"
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.remap is not None:
logpy.info(
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
# reshape z -> (batch, height, width, channel) and flatten
# z, 'b c h w -> b h w c'
z = rearrange(z, "b c h w -> b h w c")
z_flattened = z.reshape(-1, self.codebook_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
z_flattened.pow(2).sum(dim=1, keepdim=True)
+ self.embedding.weight.pow(2).sum(dim=1)
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
) # 'n d -> d n'
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
if self.training and self.embedding.update:
# EMA cluster size
encodings_sum = encodings.sum(0)
self.embedding.cluster_size_ema_update(encodings_sum)
# EMA embedding average
embed_sum = encodings.transpose(0, 1) @ z_flattened
self.embedding.embed_avg_ema_update(embed_sum)
# normalize embed_avg and update weight
self.embedding.weight_update(self.num_tokens)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
# z_q, 'b h w c -> b c h w'
z_q = rearrange(z_q, "b h w c -> b c h w")
out_dict = {
self.loss_key: loss,
"encodings": encodings,
"encoding_indices": encoding_indices,
"perplexity": perplexity,
}
return z_q, out_dict
class VectorQuantizerWithInputProjection(VectorQuantizer):
def __init__(
self,
input_dim: int,
n_codes: int,
codebook_dim: int,
beta: float = 1.0,
output_dim: Optional[int] = None,
**kwargs,
):
super().__init__(n_codes, codebook_dim, beta, **kwargs)
self.proj_in = nn.Linear(input_dim, codebook_dim)
self.output_dim = output_dim
if output_dim is not None:
self.proj_out = nn.Linear(codebook_dim, output_dim)
else:
self.proj_out = nn.Identity()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
rearr = False
in_shape = z.shape
if z.ndim > 3:
rearr = self.output_dim is not None
z = rearrange(z, "b c ... -> b (...) c")
z = self.proj_in(z)
z_q, loss_dict = super().forward(z)
z_q = self.proj_out(z_q)
if rearr:
if len(in_shape) == 4:
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
elif len(in_shape) == 5:
z_q = rearrange(
z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
)
else:
raise NotImplementedError(
f"rearranging not available for {len(in_shape)}-dimensional input."
)
return z_q, loss_dict
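
A shape sketch for the `VectorQuantizer` above (toy values, assuming the class is importable; not a test from the repository):

```python
import torch

vq = VectorQuantizer(n_e=512, e_dim=4, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 4, 8, 8)  # (b, c, h, w) latent
z_q, out = vq(z)             # nearest-codebook lookup with straight-through gradients
print(z_q.shape)                          # torch.Size([2, 4, 8, 8])
print(out["min_encoding_indices"].shape)  # torch.Size([2, 8, 8])
```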

View File

@@ -0,0 +1,349 @@
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from sgm.modules.diffusionmodules.model import (
XFORMERS_IS_AVAILABLE,
AttnBlock,
Decoder,
MemoryEfficientAttnBlock,
ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
if timesteps is None:
timesteps = self.timesteps
b, c, h, w = x.shape
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps, skip_video=False):
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class VideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_video=False):
if skip_video:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = VideoTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
attn_mode="softmax-xformers",
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
torch.nn.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
torch.nn.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps, skip_time_block=False):
if skip_time_block:
return super().forward(x)
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
assert attn_type in [
"vanilla",
"vanilla-xformers",
], f"attn_type {attn_type} not supported for spatio-temporal attention"
print(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_type = "vanilla"
if attn_type == "vanilla":
assert attn_kwargs is None
return partialclass(
VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return partialclass(
MemoryEfficientVideoBlock,
in_channels,
alpha=alpha,
merge_strategy=merge_strategy,
)
else:
raise NotImplementedError()
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)
def _make_attn(self) -> Callable:
if self.time_mode not in ["conv-only", "only-last-conv"]:
return partialclass(
make_time_attn,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_attn()
def _make_conv(self) -> Callable:
if self.time_mode != "attn-only":
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
else:
return Conv2DWrapper
def _make_resblock(self) -> Callable:
if self.time_mode not in ["attn-only", "only-last-conv"]:
return partialclass(
VideoResBlock,
video_kernel_size=self.video_kernel_size,
alpha=self.alpha,
merge_strategy=self.merge_strategy,
)
else:
return super()._make_resblock()
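
`VideoDecoder` keeps the spatial `Decoder` architecture and swaps in the temporal building blocks selected by `time_mode`. A hypothetical `decoder_config` for `instantiate_from_config` is sketched below; the module path and all values are assumptions for illustration, not the released SVD settings.

```python
# Hypothetical config; keys mirror the base Decoder arguments plus the
# video-specific options defined above.
decoder_config = {
    "target": "sgm.modules.autoencoding.temporal_ae.VideoDecoder",  # assumed path
    "params": {
        "ch": 128,
        "out_ch": 3,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "attn_resolutions": [],
        "in_channels": 3,
        "resolution": 256,
        "z_channels": 4,
        "video_kernel_size": [3, 1, 1],
        "time_mode": "conv-only",  # one of VideoDecoder.available_time_modes
    },
}
```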

View File

@@ -1,7 +0,0 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper

View File

@@ -1,62 +1,74 @@
+from typing import Dict, Union
+
+import torch
import torch.nn as nn

from ...util import append_dims, instantiate_from_config
+from .denoiser_scaling import DenoiserScaling
+from .discretizer import Discretization


class Denoiser(nn.Module):
-    def __init__(self, weighting_config, scaling_config):
+    def __init__(self, scaling_config: Dict):
        super().__init__()

-        self.weighting = instantiate_from_config(weighting_config)
-        self.scaling = instantiate_from_config(scaling_config)
+        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

-    def possibly_quantize_sigma(self, sigma):
+    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

-    def possibly_quantize_c_noise(self, c_noise):
+    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

-    def w(self, sigma):
-        return self.weighting(sigma)
-
-    def __call__(self, network, input, sigma, cond):
+    def forward(
+        self,
+        network: nn.Module,
+        input: torch.Tensor,
+        sigma: torch.Tensor,
+        cond: Dict,
+        **additional_model_inputs,
+    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
-        return network(input * c_in, c_noise, cond) * c_out + input * c_skip
+        return (
+            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+            + input * c_skip
+        )


class DiscreteDenoiser(Denoiser):
    def __init__(
        self,
-        weighting_config,
-        scaling_config,
-        num_idx,
-        discretization_config,
-        do_append_zero=False,
-        quantize_c_noise=True,
-        flip=True,
+        scaling_config: Dict,
+        num_idx: int,
+        discretization_config: Dict,
+        do_append_zero: bool = False,
+        quantize_c_noise: bool = True,
+        flip: bool = True,
    ):
-        super().__init__(weighting_config, scaling_config)
-        sigmas = instantiate_from_config(discretization_config)(
-            num_idx, do_append_zero=do_append_zero, flip=flip
-        )
+        super().__init__(scaling_config)
+        self.discretization: Discretization = instantiate_from_config(
+            discretization_config
+        )
+        sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise
+        self.num_idx = num_idx

-    def sigma_to_idx(self, sigma):
+    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

-    def idx_to_sigma(self, idx):
+    def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
        return self.sigmas[idx]

-    def possibly_quantize_sigma(self, sigma):
+    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

-    def possibly_quantize_c_noise(self, c_noise):
+    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        else:

View File

@@ -1,11 +1,24 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
import torch


+class DenoiserScaling(ABC):
+    @abstractmethod
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
+
+
class EDMScaling:
-    def __init__(self, sigma_data=0.5):
+    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

-    def __call__(self, sigma):
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
@@ -14,7 +27,9 @@ class EDMScaling:


class EpsScaling:
-    def __call__(self, sigma):
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
@@ -23,9 +38,22 @@ class EpsScaling:


class VScaling:
-    def __call__(self, sigma):
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
+
+
+class VScalingWithEDMcNoise(DenoiserScaling):
+    def __call__(
+        self, sigma: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        c_skip = 1.0 / (sigma**2 + 1.0)
+        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+        c_noise = 0.25 * sigma.log()
+        return c_skip, c_out, c_in, c_noise

View File

@@ -1,31 +1,33 @@
-from functools import partial
+import logging
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Tuple, Union

import torch
+from einops import rearrange, repeat

-from ...util import default, instantiate_from_config
+from ...util import append_dims, default

+logpy = logging.getLogger(__name__)

-class VanillaCFG:
-    """
-    implements parallelized CFG
-    """
-
-    def __init__(self, scale, dyn_thresh_config=None):
-        scale_schedule = lambda scale, sigma: scale  # independent of step
-        self.scale_schedule = partial(scale_schedule, scale)
-        self.dyn_thresh = instantiate_from_config(
-            default(
-                dyn_thresh_config,
-                {
-                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
-                },
-            )
-        )
-
-    def __call__(self, x, sigma):
+
+class Guider(ABC):
+    @abstractmethod
+    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+        pass
+
+    def prepare_inputs(
+        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+    ) -> Tuple[torch.Tensor, float, Dict]:
+        pass
+
+
+class VanillaCFG(Guider):
+    def __init__(self, scale: float):
+        self.scale = scale
+
+    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
-        scale_value = self.scale_schedule(sigma)
-        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+        x_pred = x_u + self.scale * (x_c - x_u)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
@@ -40,14 +42,58 @@ class VanillaCFG:
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


-class IdentityGuider:
-    def __call__(self, x, sigma):
+class IdentityGuider(Guider):
+    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        return x

-    def prepare_inputs(self, x, s, c, uc):
+    def prepare_inputs(
+        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+    ) -> Tuple[torch.Tensor, float, Dict]:
        c_out = dict()

        for k in c:
            c_out[k] = c[k]

        return x, s, c_out
+
+
+class LinearPredictionGuider(Guider):
+    def __init__(
+        self,
+        max_scale: float,
+        num_frames: int,
+        min_scale: float = 1.0,
+        additional_cond_keys: Optional[Union[List[str], str]] = None,
+    ):
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.num_frames = num_frames
+        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
+
+        additional_cond_keys = default(additional_cond_keys, [])
+        if isinstance(additional_cond_keys, str):
+            additional_cond_keys = [additional_cond_keys]
+        self.additional_cond_keys = additional_cond_keys
+
+    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+        x_u, x_c = x.chunk(2)
+
+        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
+        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
+        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
+        scale = append_dims(scale, x_u.ndim).to(x_u.device)
+
+        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
+
+    def prepare_inputs(
+        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+        c_out = dict()
+
+        for k in c:
+            if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
+                c_out[k] = torch.cat((uc[k], c[k]), 0)
+            else:
+                assert c[k] == uc[k]
+                c_out[k] = c[k]
+        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
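
`LinearPredictionGuider` ramps the guidance scale linearly across the frame axis instead of using a single CFG scale. A self-contained toy illustration of that ramp (frame count and scale values are assumptions):

```python
import torch
from einops import rearrange, repeat

num_frames, min_scale, max_scale = 14, 1.0, 2.5
scale = torch.linspace(min_scale, max_scale, num_frames)  # frame 0 -> min, last frame -> max

x_u = torch.randn(2 * num_frames, 4, 8, 8)  # unconditional predictions, (b*t, c, h, w)
x_c = torch.randn(2 * num_frames, 4, 8, 8)  # conditional predictions
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=num_frames)

s = repeat(scale, "t -> b t 1 1 1", b=x_u.shape[0])  # broadcastable per-frame scale
guided = rearrange(x_u + s * (x_c - x_u), "b t ... -> (b t) ...")
print(guided.shape)  # torch.Size([28, 4, 8, 8])
```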

View File

@@ -1,31 +1,34 @@
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
-from omegaconf import ListConfig

-from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
+from ...modules.encoders.modules import GeneralConditioner
+from ...util import append_dims, instantiate_from_config
+from .denoiser import Denoiser


class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
-        sigma_sampler_config,
-        type="l2",
-        offset_noise_level=0.0,
-        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
+        sigma_sampler_config: dict,
+        loss_weighting_config: dict,
+        loss_type: str = "l2",
+        offset_noise_level: float = 0.0,
+        batch2model_keys: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__()

-        assert type in ["l2", "l1", "lpips"]
+        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+        self.loss_weighting = instantiate_from_config(loss_weighting_config)

-        self.type = type
+        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level

-        if type == "lpips":
+        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
@@ -36,34 +39,67 @@ class StandardDiffusionLoss(nn.Module):

        self.batch2model_keys = set(batch2model_keys)

-    def __call__(self, network, denoiser, conditioner, input, batch):
+    def get_noised_input(
+        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
+    ) -> torch.Tensor:
+        noised_input = input + noise * sigmas_bc
+        return noised_input
+
+    def forward(
+        self,
+        network: nn.Module,
+        denoiser: Denoiser,
+        conditioner: GeneralConditioner,
+        input: torch.Tensor,
+        batch: Dict,
+    ) -> torch.Tensor:
        cond = conditioner(batch)
+        return self._forward(network, denoiser, cond, input, batch)
+
+    def _forward(
+        self,
+        network: nn.Module,
+        denoiser: Denoiser,
+        cond: Dict,
+        input: torch.Tensor,
+        batch: Dict,
+    ) -> Tuple[torch.Tensor, Dict]:
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
-
-        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
-            noise = noise + self.offset_noise_level * append_dims(
-                torch.randn(input.shape[0], device=input.device), input.ndim
+            offset_shape = (
+                (input.shape[0], 1, input.shape[2])
+                if self.n_frames is not None
+                else (input.shape[0], input.shape[1])
            )
-        noised_input = input + noise * append_dims(sigmas, input.ndim)
+            noise = noise + self.offset_noise_level * append_dims(
+                torch.randn(offset_shape, device=input.device),
+                input.ndim,
+            )
+        sigmas_bc = append_dims(sigmas, input.ndim)
+        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
-        w = append_dims(denoiser.w(sigmas), input.ndim)
+        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
-        if self.type == "l2":
+        if self.loss_type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
-        elif self.type == "l1":
+        elif self.loss_type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
-        elif self.type == "lpips":
+        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
+        else:
+            raise NotImplementedError(f"Unknown loss type {self.loss_type}")

View File

@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
import torch
class DiffusionLossWeighting(ABC):
@abstractmethod
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
pass
class UnitWeighting(DiffusionLossWeighting):
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return torch.ones_like(sigma, device=sigma.device)
class EDMWeighting(DiffusionLossWeighting):
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
class VWeighting(EDMWeighting):
def __init__(self):
super().__init__(sigma_data=1.0)
class EpsWeighting(DiffusionLossWeighting):
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma**-2.0
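
These weightings replace the old `Denoiser.w(...)` hook and are plugged into `StandardDiffusionLoss` through `loss_weighting_config`. A quick numeric comparison at a few noise levels (assuming the classes above are in scope; the sigma values are chosen arbitrarily):

```python
import torch

sigmas = torch.tensor([0.1, 1.0, 10.0])
print(UnitWeighting()(sigmas))  # all ones
print(EpsWeighting()(sigmas))   # sigma**-2: strong weight on low-noise samples
print(VWeighting()(sigmas))     # (sigma**2 + 1) / sigma**2, i.e. EDM with sigma_data = 1
```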

View File

@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
+import logging
import math
from typing import Any, Callable, Optional
@@ -8,6 +9,8 @@ import torch.nn as nn
from einops import rearrange
from packaging import version

+logpy = logging.getLogger(__name__)
+
try:
    import xformers
    import xformers.ops
@@ -15,7 +18,7 @@ try:
    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
+    logpy.warning("no module 'xformers'. Processing without...")

from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = "vanilla-xformers"
-    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
-        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        logpy.info(
+            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
+        )
        return MemoryEfficientAttnBlock(in_channels)
    elif type == "memory-efficient-cross-attn":
        attn_kwargs["query_dim"] = in_channels
@@ -633,7 +638,7 @@ class Decoder(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-        print(
+        logpy.info(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )

File diff suppressed because it is too large

View File

@@ -9,13 +9,10 @@ import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm

-from ...modules.diffusionmodules.sampling_utils import (
-    get_ancestral_step,
-    linear_multistep_coeff,
-    to_d,
-    to_neg_log_sigma,
-    to_sigma,
-)
+from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
+                                                        linear_multistep_coeff,
+                                                        to_d, to_neg_log_sigma,
+                                                        to_sigma)
from ...util import append_dims, default, instantiate_from_config

DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

View File

@@ -4,11 +4,6 @@ from scipy import integrate
from ...util import append_dims


-class NoDynamicThresholding:
-    def __call__(self, uncond, cond, scale):
-        return uncond + scale * (cond - uncond)
-
-
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

View File

@@ -1,5 +1,5 @@
"""
-adopted from
+partially adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -10,10 +10,11 @@ thanks!
"""

import math
+from typing import Optional

import torch
import torch.nn as nn
-from einops import repeat
+from einops import rearrange, repeat


def make_beta_schedule(
@@ -306,3 +307,63 @@ def avg_pool_nd(dims, *args, **kwargs):
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
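For orientation, here is a minimal standalone sketch (not part of the diff) of how the new AlphaBlender is meant to be used: it mixes per-frame spatial features with temporal features via a fixed or learned alpha. The import path and tensor shapes below are assumptions based on this commit, mirroring how SpatialVideoTransformer calls the blender.

import torch

from sgm.modules.diffusionmodules.util import AlphaBlender  # path assumed from this commit

# Hypothetical shapes, mirroring SpatialVideoTransformer: 2 clips of 8 frames,
# tokens flattened to (b * t, h * w, c).
b, t, s, c = 2, 8, 16 * 16, 64
x_spatial = torch.randn(b * t, s, c)   # stand-in for the spatial transformer output
x_temporal = torch.randn(b * t, s, c)  # stand-in for the temporal mixing block output

# One entry per frame: 1.0 marks image-only frames that bypass the temporal
# branch, 0.0 marks genuine video frames.
image_only_indicator = torch.zeros(b, t)

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
mixed = blender(x_spatial, x_temporal, image_only_indicator=image_only_indicator)
assert mixed.shape == x_spatial.shape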

View File

@@ -0,0 +1,493 @@
from functools import partial
from typing import List, Optional, Union
from einops import rearrange
from ...modules.diffusionmodules.openaimodel import *
from ...modules.video_attention import SpatialVideoTransformer
from ...util import default
from .util import AlphaBlender
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size: Union[int, List[int]] = 3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels: Optional[int] = None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator: Optional[th.Tensor] = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class VideoUNet(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_res_blocks: int,
attention_resolutions: int,
dropout: float = 0.0,
channel_mult: List[int] = (1, 2, 4, 8),
conv_resample: bool = True,
dims: int = 2,
num_classes: Optional[int] = None,
use_checkpoint: bool = False,
num_heads: int = -1,
num_head_channels: int = -1,
num_heads_upsample: int = -1,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
transformer_depth: Union[List[int], int] = 1,
transformer_depth_middle: Optional[int] = None,
context_dim: Optional[int] = None,
time_downup: bool = False,
time_context_dim: Optional[int] = None,
extra_ff_mix_layer: bool = False,
use_spatial_context: bool = False,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
spatial_transformer_attn_type: str = "softmax",
video_kernel_size: Union[int, List[int]] = 3,
use_linear_in_transformer: bool = False,
adm_in_channels: Optional[int] = None,
disable_temporal_crossattention: bool = False,
max_ddpm_temb_period: int = 10000,
):
super().__init__()
assert context_dim is not None
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1
if num_head_channels == -1:
assert num_heads != -1
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
transformer_depth_middle = default(
transformer_depth_middle, transformer_depth[-1]
)
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disabled_sa=False,
):
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
disable_self_attn=disabled_sa,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_ch,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
):
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
)
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
ds *= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_down=time_downup,
)
)
)
ch = out_ch
input_block_chans.append(ch)
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
self.middle_block = TimestepEmbedSequential(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
out_ch=None,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
out_ch=None,
time_embed_dim=time_embed_dim,
dropout=dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
get_attention_layer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
use_checkpoint=use_checkpoint,
disabled_sa=False,
)
)
if level and i == num_res_blocks:
out_ch = ch
ds //= 2
layers.append(
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_ch=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(
ch,
conv_resample,
dims=dims,
out_channels=out_ch,
third_up=time_downup,
)
)
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def forward(
self,
x: th.Tensor,
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
for module in self.input_blocks:
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
hs.append(h)
h = self.middle_block(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(
h,
emb,
context=context,
image_only_indicator=image_only_indicator,
time_context=time_context,
num_video_frames=num_video_frames,
)
h = h.type(x.dtype)
return self.out(h)
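To make the shape conventions of VideoUNet concrete, here is a small smoke-test sketch (illustrative only, not from the diff): frames are flattened into the batch dimension as (b * t, c, h, w), and num_video_frames tells the temporal layers how to unflatten them. The configuration is deliberately tiny and hypothetical, and the sketch assumes the rest of the sgm package from this commit is importable; released SVD checkpoints use far larger settings.

import torch

from sgm.modules.diffusionmodules.video_model import VideoUNet  # path assumed

# Tiny, made-up config purely for checking tensor shapes.
unet = VideoUNet(
    in_channels=4,
    model_channels=32,
    out_channels=4,
    num_res_blocks=1,
    attention_resolutions=[1, 2],
    channel_mult=[1, 2],
    num_head_channels=16,
    context_dim=16,
)

b, t = 2, 4                               # 2 clips of 4 frames each
x = torch.randn(b * t, 4, 16, 16)         # frames flattened into the batch
timesteps = torch.randint(0, 1000, (b * t,))
context = torch.randn(b * t, 77, 16)      # per-frame cross-attention context
image_only_indicator = torch.zeros(b, t)  # all entries are video frames

out = unet(
    x,
    timesteps,
    context=context,
    num_video_frames=t,
    image_only_indicator=image_only_indicator,
)
assert out.shape == x.shape  # (b * t, 4, 16, 16)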

View File

@@ -1,3 +1,4 @@
+import math
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -10,27 +11,17 @@ import torch.nn as nn
 from einops import rearrange, repeat
 from omegaconf import ListConfig
 from torch.utils.checkpoint import checkpoint
-from transformers import (
-    ByT5Tokenizer,
-    CLIPTextModel,
-    CLIPTokenizer,
-    T5EncoderModel,
-    T5Tokenizer,
-)
+from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
+                          T5EncoderModel, T5Tokenizer)

 from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
 from ...modules.diffusionmodules.model import Encoder
 from ...modules.diffusionmodules.openaimodel import Timestep
-from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.diffusionmodules.util import (extract_into_tensor,
+                                              make_beta_schedule)
 from ...modules.distributions.distributions import DiagonalGaussianDistribution
-from ...util import (
-    autocast,
-    count_params,
-    default,
-    disabled_train,
-    expand_dims_like,
-    instantiate_from_config,
-)
+from ...util import (append_dims, autocast, count_params, default,
+                     disabled_train, expand_dims_like, instantiate_from_config)


 class AbstractEmbModel(nn.Module):
@@ -173,7 +164,11 @@ class GeneralConditioner(nn.Module):
         return output

     def get_unconditional_conditioning(
-        self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
+        self,
+        batch_c: Dict,
+        batch_uc: Optional[Dict] = None,
+        force_uc_zero_embeddings: Optional[List[str]] = None,
+        force_cond_zero_embeddings: Optional[List[str]] = None,
     ):
         if force_uc_zero_embeddings is None:
             force_uc_zero_embeddings = []
@@ -181,7 +176,7 @@ class GeneralConditioner(nn.Module):
         for embedder in self.embedders:
             ucg_rates.append(embedder.ucg_rate)
             embedder.ucg_rate = 0.0
-        c = self(batch_c)
+        c = self(batch_c, force_cond_zero_embeddings)
         uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)

         for embedder, rate in zip(self.embedders, ucg_rates):
@@ -201,12 +196,6 @@ class InceptionV3(nn.Module):
         self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)

     def forward(self, inp):
-        # inp = kornia.geometry.resize(inp, (299, 299),
-        #                              interpolation='bicubic',
-        #                              align_corners=False,
-        #                              antialias=True)
-        # inp = inp.clamp(min=-1, max=1)
-
         outp = self.model(inp)

         if len(outp) == 1:
@@ -277,7 +266,6 @@ class FrozenT5Embedder(AbstractEmbModel):
         for param in self.parameters():
             param.requires_grad = False

-    # @autocast
     def forward(self, text):
         batch_encoding = self.tokenizer(
             text,
@@ -597,11 +585,12 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
         repeat_to_max_len=False,
         num_image_crops=0,
         output_tokens=False,
+        init_device=None,
     ):
         super().__init__()
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
-            device=torch.device("cpu"),
+            device=torch.device(default(init_device, "cpu")),
             pretrained=version,
         )
         del model.transformer
@@ -914,7 +903,6 @@ class LowScaleEncoder(nn.Module):
         z = self.q_sample(z, noise_level)
         if self.out_size is not None:
             z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
-        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level

     def decode(self, z):
@@ -958,3 +946,101 @@ class GaussianEncoder(Encoder, AbstractEmbModel):
         if self.flatten_output:
             z = rearrange(z, "b c h w -> b (h w ) c")
         return log, z
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
def __init__(
self,
n_cond_frames: int,
n_copies: int,
encoder_config: dict,
sigma_sampler_config: Optional[dict] = None,
sigma_cond_config: Optional[dict] = None,
is_ae: bool = False,
scale_factor: float = 1.0,
disable_encoder_autocast: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
):
super().__init__()
self.n_cond_frames = n_cond_frames
self.n_copies = n_copies
self.encoder = instantiate_from_config(encoder_config)
self.sigma_sampler = (
instantiate_from_config(sigma_sampler_config)
if sigma_sampler_config is not None
else None
)
self.sigma_cond = (
instantiate_from_config(sigma_cond_config)
if sigma_cond_config is not None
else None
)
self.is_ae = is_ae
self.scale_factor = scale_factor
self.disable_encoder_autocast = disable_encoder_autocast
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def forward(
self, vid: torch.Tensor
) -> Union[
torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, dict],
Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
]:
if self.sigma_sampler is not None:
b = vid.shape[0] // self.n_cond_frames
sigmas = self.sigma_sampler(b).to(vid.device)
if self.sigma_cond is not None:
sigma_cond = self.sigma_cond(sigmas)
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
noise = torch.randn_like(vid)
vid = vid + noise * append_dims(sigmas, vid.ndim)
with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
n_samples = (
self.en_and_decode_n_samples_a_time
if self.en_and_decode_n_samples_a_time is not None
else vid.shape[0]
)
n_rounds = math.ceil(vid.shape[0] / n_samples)
all_out = []
for n in range(n_rounds):
if self.is_ae:
out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
else:
out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
vid = torch.cat(all_out, dim=0)
vid *= self.scale_factor
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
return return_val
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
def __init__(
self,
open_clip_embedding_config: Dict,
n_cond_frames: int,
n_copies: int,
):
super().__init__()
self.n_cond_frames = n_cond_frames
self.n_copies = n_copies
self.open_clip = instantiate_from_config(open_clip_embedding_config)
def forward(self, vid):
vid = self.open_clip(vid)
vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
return vid
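The two prediction embedders above are mostly about broadcasting per-clip conditioning across the frames to be generated. As a standalone illustration of the einops flow in FrozenOpenCLIPImagePredictionEmbedder.forward (the concrete numbers are made up, and the CLIP call is replaced by a random stand-in):

import torch
from einops import rearrange, repeat

# Hypothetical sizes: 2 clips, 1 conditioning frame each, replicated for
# 14 generated frames; CLIP image embedding dimension 1280.
b, n_cond_frames, n_copies, d = 2, 1, 14, 1280
clip_emb = torch.randn(b * n_cond_frames, d)  # stand-in for self.open_clip(vid)

vid = rearrange(clip_emb, "(b t) d -> b t d", t=n_cond_frames)
vid = repeat(vid, "b t d -> (b s) t d", s=n_copies)
assert vid.shape == (b * n_copies, n_cond_frames, d)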

View File

@@ -0,0 +1,301 @@
import torch
from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
class TimeMixSequential(nn.Sequential):
def forward(self, x, context=None, timesteps=None):
for layer in self:
x = layer(x, context, timesteps)
return x
class VideoTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention,
"softmax-xformers": MemoryEfficientCrossAttention,
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
timesteps=None,
ff_in=False,
inner_dim=None,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
switch_temporal_ca_to_sa=False,
):
super().__init__()
attn_cls = self.ATTENTION_MODES[attn_mode]
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
assert int(n_heads * d_head) == inner_dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
)
self.timesteps = timesteps
self.disable_self_attn = disable_self_attn
if self.disable_self_attn:
self.attn1 = attn_cls(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim,
dropout=dropout,
) # is a cross-attention
else:
self.attn1 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
self.norm2 = nn.LayerNorm(inner_dim)
if switch_temporal_ca_to_sa:
self.attn2 = attn_cls(
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
else:
self.attn2 = attn_cls(
query_dim=inner_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(inner_dim)
self.norm3 = nn.LayerNorm(inner_dim)
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
) -> torch.Tensor:
if self.checkpoint:
return checkpoint(self._forward, x, context, timesteps)
else:
return self._forward(x, context, timesteps=timesteps)
def _forward(self, x, context=None, timesteps=None):
assert self.timesteps or timesteps
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
timesteps = self.timesteps or timesteps
B, S, C = x.shape
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
if self.disable_self_attn:
x = self.attn1(self.norm1(x), context=context) + x
else:
x = self.attn1(self.norm1(x)) + x
if self.attn2 is not None:
if self.switch_temporal_ca_to_sa:
x = self.attn2(self.norm2(x)) + x
else:
x = self.attn2(self.norm2(x), context=context) + x
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = rearrange(
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
return x
def get_last_layer(self):
return self.ff.net[-1].weight
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
attn_mode="softmax",
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
attn_type=attn_mode,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
VideoTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
x = block(
x,
context=spatial_context,
)
x_mix = x
x_mix = x_mix + emb
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator,
)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out

View File

@@ -246,3 +246,30 @@ def get_configs_path() -> str:
         if os.path.isdir(candidate):
             return candidate
     raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
"""
Will return the result of a recursive get attribute call.
E.g.:
a.b.c
= getattr(getattr(a, "b"), "c")
= get_nested_attribute(a, "b.c")
If any part of the attribute call is an integer x with current obj a, will
try to call a[x] instead of a.x first.
"""
attributes = attribute_path.split(".")
if depth is not None and depth > 0:
attributes = attributes[:depth]
assert len(attributes) > 0, "At least one attribute should be selected"
current_attribute = obj
current_key = None
for level, attribute in enumerate(attributes):
current_key = ".".join(attributes[: level + 1])
try:
id_ = int(attribute)
current_attribute = current_attribute[id_]
except ValueError:
current_attribute = getattr(current_attribute, attribute)
return (current_attribute, current_key) if return_key else current_attribute
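A short, hypothetical usage example for the helper above; the classes are made up and only illustrate the dotted-path lookup with an integer segment:

class Inner:
    value = 42

class Outer:
    items = [Inner()]

obj = Outer()

# "items.0.value" resolves to getattr(obj, "items")[0].value
assert get_nested_attribute(obj, "items.0.value") == 42

# return_key=True additionally returns the fully resolved key path
attr, key = get_nested_attribute(obj, "items.0.value", return_key=True)
assert (attr, key) == (42, "items.0.value")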