mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
Stable Video Diffusion
This commit is contained in:
119
README.md
119
README.md
@@ -4,26 +4,48 @@
|
||||
|
||||
## News
|
||||
|
||||
**November 21, 2023**
|
||||
|
||||
- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
|
||||
- [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
|
||||
frames at resolution 576x1024 given a context frame of the same size.
|
||||
We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
|
||||
- [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
|
||||
for 25 frame generation.
|
||||
- We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
|
||||
- Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
|
||||
|
||||

|
||||
|
||||
**July 26, 2023**
|
||||
- We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
|
||||
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
|
||||
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
|
||||
|
||||
- We are releasing two new open models with a
|
||||
permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
|
||||
hashes):
|
||||
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
|
||||
over `SDXL-base-0.9`.
|
||||
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
|
||||
over `SDXL-refiner-0.9`.
|
||||
|
||||

|
||||
|
||||
|
||||
**July 4, 2023**
|
||||
|
||||
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
|
||||
|
||||
**June 22, 2023**
|
||||
|
||||
|
||||
- We are releasing two new diffusion models for research purposes:
|
||||
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
|
||||
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
|
||||
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
|
||||
base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
|
||||
and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
|
||||
the OpenCLIP model.
|
||||
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
|
||||
not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
|
||||
|
||||
If you would like to access these models for your research, please apply using one of the following links:
|
||||
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
||||
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
|
||||
and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
||||
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
||||
Please log in to your Hugging Face Account with your organization email to request access.
|
||||
**We plan to do a full release soon (July).**
|
||||
@@ -32,21 +54,32 @@ Please log in to your Hugging Face Account with your organization email to reque
|
||||
|
||||
### General Philosophy
|
||||
|
||||
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
|
||||
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
|
||||
calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
|
||||
|
||||
### Changelog from the old `ldm` codebase
|
||||
|
||||
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
|
||||
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
|
||||
training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
|
||||
now `DiffusionEngine`) has been cleaned up:
|
||||
|
||||
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
|
||||
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
|
||||
conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
|
||||
see `sgm/modules/encoders/modules.py`.
|
||||
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
|
||||
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
|
||||
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
|
||||
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
|
||||
* The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
|
||||
- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable
|
||||
change is probably now the option to train continuous time models):
|
||||
* Discrete times models (denoisers) are simply a special case of continuous time models (denoisers);
|
||||
see `sgm/modules/diffusionmodules/denoiser.py`.
|
||||
* The following features are now independent: weighting of the diffusion loss
|
||||
function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
|
||||
network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
|
||||
training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
|
||||
- Autoencoding models have also been cleaned up.
|
||||
|
||||
## Installation:
|
||||
|
||||
<a name="installation"></a>
|
||||
|
||||
#### 1. Clone the repo
|
||||
@@ -60,21 +93,10 @@ cd generative-models
|
||||
|
||||
This is assuming you have navigated to the `generative-models` root after cloning it.
|
||||
|
||||
**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
|
||||
|
||||
|
||||
**PyTorch 1.13**
|
||||
|
||||
```shell
|
||||
# install required packages from pypi
|
||||
python3 -m venv .pt13
|
||||
source .pt13/bin/activate
|
||||
pip3 install -r requirements/pt13.txt
|
||||
```
|
||||
**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
|
||||
|
||||
**PyTorch 2.0**
|
||||
|
||||
|
||||
```shell
|
||||
# install required packages from pypi
|
||||
python3 -m venv .pt2
|
||||
@@ -82,7 +104,6 @@ source .pt2/bin/activate
|
||||
pip3 install -r requirements/pt2.txt
|
||||
```
|
||||
|
||||
|
||||
#### 3. Install `sgm`
|
||||
|
||||
```shell
|
||||
@@ -114,8 +135,10 @@ depending on your use case and PyTorch version, manually.
|
||||
|
||||
## Inference
|
||||
|
||||
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
|
||||
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
|
||||
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
|
||||
in `scripts/demo/sampling.py`.
|
||||
We provide file hashes for the complete file as well as for only the saved tensors in the file (
|
||||
see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
|
||||
The following models are currently supported:
|
||||
|
||||
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
@@ -136,19 +159,20 @@ The following models are currently supported:
|
||||
**Weights for SDXL**:
|
||||
|
||||
**SDXL-1.0:**
|
||||
The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
|
||||
The weights of SDXL-1.0 are available (subject to
|
||||
a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
|
||||
|
||||
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
|
||||
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
|
||||
|
||||
|
||||
**SDXL-0.9:**
|
||||
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
|
||||
If you would like to access these models for your research, please apply using one of the following links:
|
||||
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
||||
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
|
||||
and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
|
||||
This means that you can apply for any of the two links - and if you are granted - you can access both.
|
||||
Please log in to your Hugging Face Account with your organization email to request access.
|
||||
|
||||
|
||||
After obtaining the weights, place them into `checkpoints/`.
|
||||
Next, start the demo using
|
||||
|
||||
@@ -166,6 +190,7 @@ not the same as in previous Stable Diffusion 1.x/2.x versions.
|
||||
|
||||
To run the script you need to either have a working installation as above or
|
||||
try an _experimental_ import using only a minimal amount of packages:
|
||||
|
||||
```bash
|
||||
python -m venv .detect
|
||||
source .detect/bin/activate
|
||||
@@ -177,6 +202,7 @@ pip install --no-deps invisible-watermark
|
||||
To run the script you need to have a working installation as above. The script
|
||||
is then useable in the following ways (don't forget to activate your
|
||||
virtual environment beforehand, e.g. `source .pt1/bin/activate`):
|
||||
|
||||
```bash
|
||||
# test a single file
|
||||
python scripts/demo/detect.py <your filename here>
|
||||
@@ -203,11 +229,21 @@ run
|
||||
python main.py --base configs/example_training/toy/mnist_cond.yaml
|
||||
```
|
||||
|
||||
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
|
||||
**NOTE 1:** Using the non-toy-dataset
|
||||
configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
|
||||
and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
|
||||
used dataset (which is expected to stored in tar-file in
|
||||
the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search
|
||||
for comments containing `USER:` in the respective config.
|
||||
|
||||
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
|
||||
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for
|
||||
autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
|
||||
only `pytorch1.13` is supported.
|
||||
|
||||
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
|
||||
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
|
||||
retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
|
||||
the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
|
||||
for the provided text-to-image configs.
|
||||
|
||||
### Building New Diffusion Models
|
||||
|
||||
@@ -216,7 +252,8 @@ python main.py --base configs/example_training/toy/mnist_cond.yaml
|
||||
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
|
||||
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
|
||||
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
|
||||
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
|
||||
guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
|
||||
text-conditioning or `cls` for class-conditioning.
|
||||
When computing conditionings, the embedder will get `batch[input_key]` as input.
|
||||
We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
|
||||
appropriately.
|
||||
@@ -229,7 +266,8 @@ enough as we plan to experiment with transformer-based diffusion backbones.
|
||||
|
||||
#### Loss
|
||||
|
||||
The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
|
||||
The loss is configured through `loss_config`. For standard diffusion model training, you will have to
|
||||
set `sigma_sampler_config`.
|
||||
|
||||
#### Sampler config
|
||||
|
||||
@@ -239,8 +277,9 @@ guidance.
|
||||
|
||||
### Dataset Handling
|
||||
|
||||
|
||||
For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
|
||||
For large scale training we recommend using the data pipelines from
|
||||
our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement
|
||||
and automatically included when following the steps from the [Installation section](#installation).
|
||||
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
|
||||
data keys/values,
|
||||
e.g.,
|
||||
|
||||
BIN
assets/test_image.png
Normal file
BIN
assets/test_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 482 KiB |
BIN
assets/tile.gif
Normal file
BIN
assets/tile.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 MiB |
@@ -29,25 +29,14 @@ model:
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 4 ]
|
||||
ch_mult: [1, 2, 4]
|
||||
num_res_blocks: 4
|
||||
attn_resolutions: [ ]
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params:
|
||||
attn_type: none
|
||||
double_z: False
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 4 ]
|
||||
num_res_blocks: 4
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
params: ${model.params.encoder_config.params}
|
||||
|
||||
data:
|
||||
target: sgm.data.dataset.StableDataModuleFromConfig
|
||||
@@ -55,18 +44,18 @@ data:
|
||||
train:
|
||||
datapipeline:
|
||||
urls:
|
||||
- "DATA-PATH"
|
||||
- DATA-PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000
|
||||
|
||||
decoders:
|
||||
- "pil"
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: 'jpg'
|
||||
key: jpg
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
model:
|
||||
base_learning_rate: 4.5e-6
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
input_key: jpg
|
||||
monitor: val/loss/rec
|
||||
disc_start_iter: 0
|
||||
|
||||
encoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Encoder
|
||||
params:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 8
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
|
||||
decoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Decoder
|
||||
params: ${model.params.encoder_config.params}
|
||||
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
loss_config:
|
||||
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
|
||||
params:
|
||||
perceptual_weight: 0.25
|
||||
disc_start: 20001
|
||||
disc_weight: 0.5
|
||||
learn_logvar: True
|
||||
|
||||
regularization_weights:
|
||||
kl_loss: 1.0
|
||||
|
||||
data:
|
||||
target: sgm.data.dataset.StableDataModuleFromConfig
|
||||
params:
|
||||
train:
|
||||
datapipeline:
|
||||
urls:
|
||||
- DATA-PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000
|
||||
|
||||
decoders:
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: jpg
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 256
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.ToTensor
|
||||
- target: sdata.mappers.Rescaler
|
||||
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
|
||||
params:
|
||||
h_key: height
|
||||
w_key: width
|
||||
|
||||
loader:
|
||||
batch_size: 8
|
||||
num_workers: 4
|
||||
|
||||
|
||||
lightning:
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.DDPStrategy
|
||||
params:
|
||||
find_unused_parameters: True
|
||||
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 5000
|
||||
|
||||
callbacks:
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
params:
|
||||
every_n_train_steps: 50000
|
||||
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
enable_autocast: False
|
||||
batch_frequency: 1000
|
||||
max_images: 8
|
||||
increase_log_steps: True
|
||||
|
||||
trainer:
|
||||
devices: 0,
|
||||
limit_val_batches: 50
|
||||
benchmark: True
|
||||
accumulate_grad_batches: 1
|
||||
val_check_interval: 10000
|
||||
@@ -21,8 +21,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -32,7 +30,6 @@ model:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 256
|
||||
@@ -42,7 +39,6 @@ model:
|
||||
num_head_channels: 64
|
||||
num_classes: sequential
|
||||
adm_in_channels: 1024
|
||||
use_spatial_transformer: true
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
@@ -51,32 +47,31 @@ model:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: True
|
||||
input_key: cls
|
||||
ucg_rate: 0.2
|
||||
target: sgm.modules.encoders.modules.ClassEmbedder
|
||||
params:
|
||||
add_sequence_dim: True # will be used through crossattn then
|
||||
add_sequence_dim: True
|
||||
embed_dim: 1024
|
||||
n_classes: 1000
|
||||
# vector cond
|
||||
|
||||
- is_trainable: False
|
||||
ucg_rate: 0.2
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
ucg_rate: 0.2
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
ckpt_path: CKPT_PATH
|
||||
embed_dim: 4
|
||||
@@ -98,7 +93,9 @@ model:
|
||||
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
@@ -127,18 +124,18 @@ data:
|
||||
datapipeline:
|
||||
urls:
|
||||
# USER: adapt this path the root of your custom dataset
|
||||
- "DATA_PATH"
|
||||
- DATA_PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
|
||||
|
||||
decoders:
|
||||
- "pil"
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
|
||||
key: jpg # USER: you might wanna adapt this for your custom dataset
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
@@ -5,10 +5,6 @@ model:
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
params:
|
||||
@@ -17,7 +13,6 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
model_channels: 32
|
||||
@@ -46,6 +41,10 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ model:
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
params:
|
||||
@@ -17,7 +13,6 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 1
|
||||
out_channels: 1
|
||||
model_channels: 32
|
||||
@@ -32,6 +27,10 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ model:
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
params:
|
||||
@@ -17,13 +13,12 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 1
|
||||
out_channels: 1
|
||||
model_channels: 32
|
||||
attention_resolutions: [ ]
|
||||
attention_resolutions: []
|
||||
num_res_blocks: 4
|
||||
channel_mult: [ 1, 2, 2 ]
|
||||
channel_mult: [1, 2, 2]
|
||||
num_head_channels: 32
|
||||
num_classes: sequential
|
||||
adm_in_channels: 128
|
||||
@@ -33,7 +28,7 @@ model:
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: True
|
||||
input_key: "cls"
|
||||
input_key: cls
|
||||
ucg_rate: 0.2
|
||||
target: sgm.modules.encoders.modules.ClassEmbedder
|
||||
params:
|
||||
@@ -46,6 +41,10 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
discretization_config:
|
||||
@@ -17,13 +15,12 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 1
|
||||
out_channels: 1
|
||||
model_channels: 32
|
||||
attention_resolutions: [ ]
|
||||
attention_resolutions: []
|
||||
num_res_blocks: 4
|
||||
channel_mult: [ 1, 2, 2 ]
|
||||
channel_mult: [1, 2, 2]
|
||||
num_head_channels: 32
|
||||
num_classes: sequential
|
||||
adm_in_channels: 128
|
||||
@@ -33,7 +30,7 @@ model:
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: True
|
||||
input_key: "cls"
|
||||
input_key: cls
|
||||
ucg_rate: 0.2
|
||||
target: sgm.modules.encoders.modules.ClassEmbedder
|
||||
params:
|
||||
@@ -46,6 +43,8 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
|
||||
@@ -5,10 +5,6 @@ model:
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
params:
|
||||
@@ -17,7 +13,6 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 1
|
||||
out_channels: 1
|
||||
model_channels: 32
|
||||
@@ -25,7 +20,7 @@ model:
|
||||
num_res_blocks: 4
|
||||
channel_mult: [1, 2, 2]
|
||||
num_head_channels: 32
|
||||
num_classes: "sequential"
|
||||
num_classes: sequential
|
||||
adm_in_channels: 128
|
||||
|
||||
conditioner_config:
|
||||
@@ -33,7 +28,7 @@ model:
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: True
|
||||
input_key: "cls"
|
||||
input_key: cls
|
||||
ucg_rate: 0.2
|
||||
target: sgm.modules.encoders.modules.ClassEmbedder
|
||||
params:
|
||||
@@ -46,6 +41,11 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_type: l1
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
||||
|
||||
@@ -62,11 +62,6 @@ model:
|
||||
params:
|
||||
scale: 3.0
|
||||
|
||||
loss_config:
|
||||
target: sgm.modules.diffusionmodules.StandardDiffusionLoss
|
||||
params:
|
||||
type: l1
|
||||
|
||||
data:
|
||||
target: sgm.data.mnist.MNISTLoader
|
||||
params:
|
||||
|
||||
@@ -7,10 +7,6 @@ model:
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
|
||||
params:
|
||||
@@ -19,7 +15,6 @@ model:
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
in_channels: 1
|
||||
out_channels: 1
|
||||
model_channels: 32
|
||||
@@ -48,6 +43,10 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
|
||||
params:
|
||||
sigma_data: 1.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
|
||||
|
||||
|
||||
@@ -10,19 +10,17 @@ model:
|
||||
scheduler_config:
|
||||
target: sgm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ]
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
warm_up_steps: [10000]
|
||||
cycle_lengths: [10000000000000]
|
||||
f_start: [1.e-6]
|
||||
f_max: [1.]
|
||||
f_min: [1.]
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -32,18 +30,16 @@ model:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 1, 2, 4 ]
|
||||
attention_resolutions: [1, 2, 4]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
num_classes: sequential
|
||||
adm_in_channels: 1792
|
||||
num_heads: 1
|
||||
use_spatial_transformer: true
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
@@ -52,7 +48,6 @@ model:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: True
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
@@ -60,23 +55,23 @@ model:
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
always_return_pooled: True
|
||||
# vector cond
|
||||
|
||||
- is_trainable: False
|
||||
ucg_rate: 0.1
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
ckpt_path: CKPT_PATH
|
||||
embed_dim: 4
|
||||
@@ -99,6 +94,8 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
@@ -127,18 +124,18 @@ data:
|
||||
datapipeline:
|
||||
urls:
|
||||
# USER: adapt this path the root of your custom dataset
|
||||
- "DATA_PATH"
|
||||
- DATA_PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
|
||||
|
||||
decoders:
|
||||
- "pil"
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
|
||||
key: jpg # USER: you might wanna adapt this for your custom dataset
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
@@ -10,19 +10,17 @@ model:
|
||||
scheduler_config:
|
||||
target: sgm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ]
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
warm_up_steps: [10000]
|
||||
cycle_lengths: [10000000000000]
|
||||
f_start: [1.e-6]
|
||||
f_max: [1.]
|
||||
f_min: [1.]
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -32,18 +30,16 @@ model:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 1, 2, 4 ]
|
||||
attention_resolutions: [1, 2, 4]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
num_classes: sequential
|
||||
adm_in_channels: 1792
|
||||
num_heads: 1
|
||||
use_spatial_transformer: true
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
@@ -52,30 +48,30 @@ model:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: True
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
legacy_ucg_value: ""
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
always_return_pooled: True
|
||||
# vector cond
|
||||
|
||||
- is_trainable: False
|
||||
ucg_rate: 0.1
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
ckpt_path: CKPT_PATH
|
||||
embed_dim: 4
|
||||
@@ -88,9 +84,9 @@ model:
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 4, 4 ]
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
@@ -98,6 +94,8 @@ model:
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
|
||||
params:
|
||||
loss_weighting_config:
|
||||
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
@@ -126,19 +124,19 @@ data:
|
||||
datapipeline:
|
||||
urls:
|
||||
# USER: adapt this path the root of your custom dataset
|
||||
- "DATA_PATH"
|
||||
- DATA_PATH
|
||||
pipeline_config:
|
||||
shardshuffle: 10000
|
||||
sample_shuffle: 10000
|
||||
|
||||
|
||||
decoders:
|
||||
- "pil"
|
||||
- pil
|
||||
|
||||
postprocessors:
|
||||
- target: sdata.mappers.TorchVisionImageTransforms
|
||||
params:
|
||||
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
|
||||
key: jpg # USER: you might wanna adapt this for your custom dataset
|
||||
transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
@@ -9,8 +9,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -20,7 +18,6 @@ model:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
@@ -28,17 +25,14 @@ model:
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
@@ -47,7 +41,7 @@ model:
|
||||
layer: penultimate
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
|
||||
@@ -9,8 +9,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
|
||||
discretization_config:
|
||||
@@ -20,7 +18,6 @@ model:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
@@ -28,17 +25,14 @@ model:
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
@@ -47,7 +41,7 @@ model:
|
||||
layer: penultimate
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
|
||||
@@ -9,8 +9,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -29,25 +27,22 @@ model:
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
|
||||
transformer_depth: [1, 2, 10]
|
||||
context_dim: 2048
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
layer: hidden
|
||||
layer_idx: 11
|
||||
# crossattn and vector cond
|
||||
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
@@ -58,27 +53,27 @@ model:
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
legacy: False
|
||||
# vector cond
|
||||
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: target_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
|
||||
@@ -9,8 +9,6 @@ model:
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
@@ -29,18 +27,15 @@ model:
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 4
|
||||
context_dim: [1280, 1280, 1280, 1280] # 1280
|
||||
context_dim: [1280, 1280, 1280, 1280]
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn and vector cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
@@ -51,27 +46,27 @@ model:
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
# vector cond
|
||||
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
outdim: 256
|
||||
|
||||
- is_trainable: False
|
||||
input_key: aesthetic_score
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by one
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
|
||||
131
configs/inference/svd.yaml
Normal file
131
configs/inference/svd.yaml
Normal file
@@ -0,0 +1,131 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Encoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
decoder_config:
|
||||
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
video_kernel_size: [3, 1, 1]
|
||||
114
configs/inference/svd_image_decoder.yaml
Normal file
114
configs/inference/svd_image_decoder.yaml
Normal file
@@ -0,0 +1,114 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
31
model_licenses/LICENSE-SDV
Normal file
31
model_licenses/LICENSE-SDV
Normal file
@@ -0,0 +1,31 @@
|
||||
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT
|
||||
Dated: November 21, 2023
|
||||
|
||||
“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
|
||||
|
||||
"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.
|
||||
"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.
|
||||
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.
|
||||
|
||||
"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
||||
|
||||
"Stability AI" or "we" means Stability AI Ltd.
|
||||
|
||||
"Software" means, collectively, Stability AI’s proprietary StableCode made available under this Agreement.
|
||||
|
||||
“Software Products” means Software and Documentation.
|
||||
|
||||
By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
|
||||
|
||||
|
||||
|
||||
License Rights and Redistribution.
|
||||
Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.
|
||||
b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.
|
||||
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
|
||||
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
||||
3. Intellectual Property.
|
||||
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
|
||||
Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.
|
||||
If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
|
||||
4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
|
||||
59
scripts/demo/discretization.py
Normal file
59
scripts/demo/discretization.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
|
||||
from sgm.modules.diffusionmodules.discretizer import Discretization
|
||||
|
||||
|
||||
class Img2ImgDiscretizationWrapper:
|
||||
"""
|
||||
wraps a discretizer, and prunes the sigmas
|
||||
params:
|
||||
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||
"""
|
||||
|
||||
def __init__(self, discretization: Discretization, strength: float = 1.0):
|
||||
self.discretization = discretization
|
||||
self.strength = strength
|
||||
assert 0.0 <= self.strength <= 1.0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# sigmas start large first, and decrease then
|
||||
sigmas = self.discretization(*args, **kwargs)
|
||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
print(f"sigmas after pruning: ", sigmas)
|
||||
return sigmas
|
||||
|
||||
|
||||
class Txt2NoisyDiscretizationWrapper:
|
||||
"""
|
||||
wraps a discretizer, and prunes the sigmas
|
||||
params:
|
||||
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, discretization: Discretization, strength: float = 0.0, original_steps=None
|
||||
):
|
||||
self.discretization = discretization
|
||||
self.strength = strength
|
||||
self.original_steps = original_steps
|
||||
assert 0.0 <= self.strength <= 1.0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# sigmas start large first, and decrease then
|
||||
sigmas = self.discretization(*args, **kwargs)
|
||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
if self.original_steps is None:
|
||||
steps = len(sigmas)
|
||||
else:
|
||||
steps = self.original_steps + 1
|
||||
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
|
||||
sigmas = sigmas[prune_index:]
|
||||
print("prune index:", prune_index)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
print(f"sigmas after pruning: ", sigmas)
|
||||
return sigmas
|
||||
@@ -253,7 +253,10 @@ if __name__ == "__main__":
|
||||
st.title("Stable Diffusion")
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
version_dict = VERSION2SPECS[version]
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
if st.checkbox("Load Model"):
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
else:
|
||||
mode = "skip"
|
||||
st.write("__________________________")
|
||||
|
||||
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
||||
@@ -269,10 +272,11 @@ if __name__ == "__main__":
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
if mode != "skip":
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
|
||||
@@ -333,6 +337,8 @@ if __name__ == "__main__":
|
||||
filter=state.get("filter"),
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
elif mode == "skip":
|
||||
out = None
|
||||
else:
|
||||
raise ValueError(f"unknown mode {mode}")
|
||||
if isinstance(out, (tuple, list)):
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from typing import List, Union
|
||||
from glob import glob
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as TT
|
||||
from einops import rearrange, repeat
|
||||
from imwatermark import WatermarkEncoder
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
@@ -12,63 +17,22 @@ from PIL import Image
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from torch import autocast
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import make_grid
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
from sgm.modules.diffusionmodules.sampling import (
|
||||
DPMPP2MSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
EulerAncestralSampler,
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
LinearMultistepSampler,
|
||||
)
|
||||
from sgm.util import append_dims, instantiate_from_config
|
||||
|
||||
|
||||
class WatermarkEmbedder:
|
||||
def __init__(self, watermark):
|
||||
self.watermark = watermark
|
||||
self.num_bits = len(WATERMARK_BITS)
|
||||
self.encoder = WatermarkEncoder()
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def __call__(self, image: torch.Tensor):
|
||||
"""
|
||||
Adds a predefined watermark to the input image
|
||||
|
||||
Args:
|
||||
image: ([N,] B, C, H, W) in range [0, 1]
|
||||
|
||||
Returns:
|
||||
same as input but watermarked
|
||||
"""
|
||||
# watermarking libary expects input as cv2 BGR format
|
||||
squeeze = len(image.shape) == 4
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange(
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
image = torch.from_numpy(
|
||||
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
||||
).to(image.device)
|
||||
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||
if squeeze:
|
||||
image = image[0]
|
||||
return image
|
||||
|
||||
|
||||
# A fixed 48-bit message that was choosen at random
|
||||
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
|
||||
from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
|
||||
Txt2NoisyDiscretizationWrapper)
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
||||
DeepFloydDataFiltering
|
||||
from sgm.inference.helpers import embed_watermark
|
||||
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
|
||||
VanillaCFG)
|
||||
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
EulerAncestralSampler,
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
LinearMultistepSampler)
|
||||
from sgm.util import append_dims, default, instantiate_from_config
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
@@ -164,11 +128,12 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
if prompt is None:
|
||||
prompt = st.text_input(
|
||||
"Prompt", "A professional photograph of an astronaut riding a pig"
|
||||
)
|
||||
prompt = "A professional photograph of an astronaut riding a pig"
|
||||
if negative_prompt is None:
|
||||
negative_prompt = st.text_input("Negative prompt", "")
|
||||
negative_prompt = ""
|
||||
|
||||
prompt = st.text_input("Prompt", prompt)
|
||||
negative_prompt = st.text_input("Negative prompt", negative_prompt)
|
||||
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
@@ -203,13 +168,35 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||
value_dict["target_width"] = init_dict["target_width"]
|
||||
value_dict["target_height"] = init_dict["target_height"]
|
||||
|
||||
if key in ["fps_id", "fps"]:
|
||||
fps = st.number_input("fps", value=6, min_value=1)
|
||||
|
||||
value_dict["fps"] = fps
|
||||
value_dict["fps_id"] = fps - 1
|
||||
|
||||
if key == "motion_bucket_id":
|
||||
mb_id = st.number_input("motion bucket id", 0, 511, value=127)
|
||||
value_dict["motion_bucket_id"] = mb_id
|
||||
|
||||
if key == "pool_image":
|
||||
st.text("Image for pool conditioning")
|
||||
image = load_img(
|
||||
key="pool_image_input",
|
||||
size=224,
|
||||
center_crop=True,
|
||||
)
|
||||
if image is None:
|
||||
st.info("Need an image here")
|
||||
image = torch.zeros(1, 3, 224, 224)
|
||||
value_dict["pool_image"] = image
|
||||
|
||||
return value_dict
|
||||
|
||||
|
||||
def perform_save_locally(save_path, samples):
|
||||
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||
base_count = len(os.listdir(os.path.join(save_path)))
|
||||
samples = embed_watemark(samples)
|
||||
samples = embed_watermark(samples)
|
||||
for sample in samples:
|
||||
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(sample.astype(np.uint8)).save(
|
||||
@@ -228,95 +215,99 @@ def init_save_locally(_dir, init_value: bool = False):
|
||||
return save_locally, save_path
|
||||
|
||||
|
||||
class Img2ImgDiscretizationWrapper:
|
||||
"""
|
||||
wraps a discretizer, and prunes the sigmas
|
||||
params:
|
||||
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||
"""
|
||||
|
||||
def __init__(self, discretization, strength: float = 1.0):
|
||||
self.discretization = discretization
|
||||
self.strength = strength
|
||||
assert 0.0 <= self.strength <= 1.0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# sigmas start large first, and decrease then
|
||||
sigmas = self.discretization(*args, **kwargs)
|
||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
print(f"sigmas after pruning: ", sigmas)
|
||||
return sigmas
|
||||
|
||||
|
||||
class Txt2NoisyDiscretizationWrapper:
|
||||
"""
|
||||
wraps a discretizer, and prunes the sigmas
|
||||
params:
|
||||
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
|
||||
"""
|
||||
|
||||
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
|
||||
self.discretization = discretization
|
||||
self.strength = strength
|
||||
self.original_steps = original_steps
|
||||
assert 0.0 <= self.strength <= 1.0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# sigmas start large first, and decrease then
|
||||
sigmas = self.discretization(*args, **kwargs)
|
||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
if self.original_steps is None:
|
||||
steps = len(sigmas)
|
||||
else:
|
||||
steps = self.original_steps + 1
|
||||
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
|
||||
sigmas = sigmas[prune_index:]
|
||||
print("prune index:", prune_index)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
print(f"sigmas after pruning: ", sigmas)
|
||||
return sigmas
|
||||
|
||||
|
||||
def get_guider(key):
|
||||
def get_guider(options, key):
|
||||
guider = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[
|
||||
"VanillaCFG",
|
||||
"IdentityGuider",
|
||||
"LinearPredictionGuider",
|
||||
],
|
||||
options.get("guider", 0),
|
||||
)
|
||||
|
||||
additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
|
||||
|
||||
if guider == "IdentityGuider":
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif guider == "VanillaCFG":
|
||||
scale = st.number_input(
|
||||
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
|
||||
scale_schedule = st.sidebar.selectbox(
|
||||
f"Scale schedule #{key}",
|
||||
["Identity", "Oscillating"],
|
||||
)
|
||||
|
||||
thresholder = st.sidebar.selectbox(
|
||||
f"Thresholder #{key}",
|
||||
[
|
||||
"None",
|
||||
],
|
||||
)
|
||||
if scale_schedule == "Identity":
|
||||
scale = st.number_input(
|
||||
f"cfg-scale #{key}",
|
||||
value=options.get("cfg", 5.0),
|
||||
min_value=0.0,
|
||||
)
|
||||
|
||||
if thresholder == "None":
|
||||
dyn_thresh_config = {
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
scale_schedule_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
|
||||
"params": {"scale": scale},
|
||||
}
|
||||
|
||||
elif scale_schedule == "Oscillating":
|
||||
small_scale = st.number_input(
|
||||
f"small cfg-scale #{key}",
|
||||
value=4.0,
|
||||
min_value=0.0,
|
||||
)
|
||||
|
||||
large_scale = st.number_input(
|
||||
f"large cfg-scale #{key}",
|
||||
value=16.0,
|
||||
min_value=0.0,
|
||||
)
|
||||
|
||||
sigma_cutoff = st.number_input(
|
||||
f"sigma cutoff #{key}",
|
||||
value=1.0,
|
||||
min_value=0.0,
|
||||
)
|
||||
|
||||
scale_schedule_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
|
||||
"params": {
|
||||
"small_scale": small_scale,
|
||||
"large_scale": large_scale,
|
||||
"sigma_cutoff": sigma_cutoff,
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||
"params": {
|
||||
"scale_schedule_config": scale_schedule_config,
|
||||
**additional_guider_kwargs,
|
||||
},
|
||||
}
|
||||
elif guider == "LinearPredictionGuider":
|
||||
max_scale = st.number_input(
|
||||
f"max-cfg-scale #{key}",
|
||||
value=options.get("cfg", 1.5),
|
||||
min_value=1.0,
|
||||
)
|
||||
min_scale = st.number_input(
|
||||
f"min guidance scale",
|
||||
value=options.get("min_cfg", 1.0),
|
||||
min_value=1.0,
|
||||
max_value=10.0,
|
||||
)
|
||||
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
|
||||
"params": {
|
||||
"max_scale": max_scale,
|
||||
"min_scale": min_scale,
|
||||
"num_frames": options["num_frames"],
|
||||
**additional_guider_kwargs,
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -325,18 +316,21 @@ def get_guider(key):
|
||||
|
||||
def init_sampling(
|
||||
key=1,
|
||||
img2img_strength=1.0,
|
||||
specify_num_samples=True,
|
||||
stage2strength=None,
|
||||
img2img_strength: Optional[float] = None,
|
||||
specify_num_samples: bool = True,
|
||||
stage2strength: Optional[float] = None,
|
||||
options: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
options = {} if options is None else options
|
||||
|
||||
num_rows, num_cols = 1, 1
|
||||
if specify_num_samples:
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
f"num cols #{key}", value=num_cols, min_value=1, max_value=10
|
||||
)
|
||||
|
||||
steps = st.sidebar.number_input(
|
||||
f"steps #{key}", value=40, min_value=1, max_value=1000
|
||||
f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
|
||||
)
|
||||
sampler = st.sidebar.selectbox(
|
||||
f"Sampler #{key}",
|
||||
@@ -348,7 +342,7 @@ def init_sampling(
|
||||
"DPMPP2MSampler",
|
||||
"LinearMultistepSampler",
|
||||
],
|
||||
0,
|
||||
options.get("sampler", 0),
|
||||
)
|
||||
discretization = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
@@ -356,14 +350,15 @@ def init_sampling(
|
||||
"LegacyDDPMDiscretization",
|
||||
"EDMDiscretization",
|
||||
],
|
||||
options.get("discretization", 0),
|
||||
)
|
||||
|
||||
discretization_config = get_discretization(discretization, key=key)
|
||||
discretization_config = get_discretization(discretization, options=options, key=key)
|
||||
|
||||
guider_config = get_guider(key=key)
|
||||
guider_config = get_guider(options=options, key=key)
|
||||
|
||||
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
|
||||
if img2img_strength < 1.0:
|
||||
if img2img_strength is not None:
|
||||
st.warning(
|
||||
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
|
||||
)
|
||||
@@ -377,15 +372,19 @@ def init_sampling(
|
||||
return sampler, num_rows, num_cols
|
||||
|
||||
|
||||
def get_discretization(discretization, key=1):
|
||||
def get_discretization(discretization, options, key=1):
|
||||
if discretization == "LegacyDDPMDiscretization":
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||
}
|
||||
elif discretization == "EDMDiscretization":
|
||||
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
|
||||
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
|
||||
rho = st.number_input(f"rho #{key}", value=3.0)
|
||||
sigma_min = st.number_input(
|
||||
f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
|
||||
) # 0.0292
|
||||
sigma_max = st.number_input(
|
||||
f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
|
||||
) # 14.6146
|
||||
rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||
"params": {
|
||||
@@ -474,8 +473,8 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
|
||||
return sampler
|
||||
|
||||
|
||||
def get_interactive_image(key=None) -> Image.Image:
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
def get_interactive_image() -> Image.Image:
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
@@ -483,8 +482,12 @@ def get_interactive_image(key=None) -> Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def load_img(display=True, key=None):
|
||||
image = get_interactive_image(key=key)
|
||||
def load_img(
|
||||
display: bool = True,
|
||||
size: Union[None, int, Tuple[int, int]] = None,
|
||||
center_crop: bool = False,
|
||||
):
|
||||
image = get_interactive_image()
|
||||
if image is None:
|
||||
return None
|
||||
if display:
|
||||
@@ -492,12 +495,15 @@ def load_img(display=True, key=None):
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x * 2.0 - 1.0),
|
||||
]
|
||||
)
|
||||
transform = []
|
||||
if size is not None:
|
||||
transform.append(transforms.Resize(size))
|
||||
if center_crop:
|
||||
transform.append(transforms.CenterCrop(size))
|
||||
transform.append(transforms.ToTensor())
|
||||
transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
|
||||
|
||||
transform = transforms.Compose(transform)
|
||||
img = transform(image)[None, ...]
|
||||
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
|
||||
return img
|
||||
@@ -518,15 +524,18 @@ def do_sample(
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings: List = None,
|
||||
force_uc_zero_embeddings: Optional[List] = None,
|
||||
force_cond_zero_embeddings: Optional[List] = None,
|
||||
batch2model_input: List = None,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
T=None,
|
||||
additional_batch_uc_fields=None,
|
||||
decoding_t=None,
|
||||
):
|
||||
if force_uc_zero_embeddings is None:
|
||||
force_uc_zero_embeddings = []
|
||||
if batch2model_input is None:
|
||||
batch2model_input = []
|
||||
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
|
||||
batch2model_input = default(batch2model_input, [])
|
||||
additional_batch_uc_fields = default(additional_batch_uc_fields, [])
|
||||
|
||||
st.text("Sampling")
|
||||
|
||||
@@ -535,24 +544,25 @@ def do_sample(
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
num_samples = [num_samples]
|
||||
if T is not None:
|
||||
num_samples = [num_samples, T]
|
||||
else:
|
||||
num_samples = [num_samples]
|
||||
|
||||
load_model(model.conditioner)
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
num_samples,
|
||||
T=T,
|
||||
additional_batch_uc_fields=additional_batch_uc_fields,
|
||||
)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
print(key, batch[key].shape)
|
||||
elif isinstance(batch[key], list):
|
||||
print(key, [len(l) for l in batch[key]])
|
||||
else:
|
||||
print(key, batch[key])
|
||||
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
||||
)
|
||||
unload_model(model.conditioner)
|
||||
|
||||
@@ -561,10 +571,29 @@ def do_sample(
|
||||
c[k], uc[k] = map(
|
||||
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
|
||||
)
|
||||
if k in ["crossattn", "concat"] and T is not None:
|
||||
uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
|
||||
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
|
||||
c[k] = repeat(c[k], "b ... -> b t ...", t=T)
|
||||
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
|
||||
|
||||
additional_model_inputs = {}
|
||||
for k in batch2model_input:
|
||||
additional_model_inputs[k] = batch[k]
|
||||
if k == "image_only_indicator":
|
||||
assert T is not None
|
||||
|
||||
if isinstance(
|
||||
sampler.guider, (VanillaCFG, LinearPredictionGuider)
|
||||
):
|
||||
additional_model_inputs[k] = torch.zeros(
|
||||
num_samples[0] * 2, num_samples[1]
|
||||
).to("cuda")
|
||||
else:
|
||||
additional_model_inputs[k] = torch.zeros(num_samples).to(
|
||||
"cuda"
|
||||
)
|
||||
else:
|
||||
additional_model_inputs[k] = batch[k]
|
||||
|
||||
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||
randn = torch.randn(shape).to("cuda")
|
||||
@@ -581,6 +610,9 @@ def do_sample(
|
||||
unload_model(model.denoiser)
|
||||
|
||||
load_model(model.first_stage_model)
|
||||
model.en_and_decode_n_samples_a_time = (
|
||||
decoding_t # Decode n frames at a time
|
||||
)
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
unload_model(model.first_stage_model)
|
||||
@@ -588,16 +620,32 @@ def do_sample(
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
grid = torch.stack([samples])
|
||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||
outputs.image(grid.cpu().numpy())
|
||||
if T is None:
|
||||
grid = torch.stack([samples])
|
||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||
outputs.image(grid.cpu().numpy())
|
||||
else:
|
||||
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
|
||||
for i, vid in enumerate(as_vids):
|
||||
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
|
||||
st.image(
|
||||
grid.cpu().numpy(),
|
||||
f"Sample #{i} as image",
|
||||
)
|
||||
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
def get_batch(
|
||||
keys,
|
||||
value_dict: dict,
|
||||
N: Union[List, ListConfig],
|
||||
device: str = "cuda",
|
||||
T: int = None,
|
||||
additional_batch_uc_fields: List[str] = [],
|
||||
):
|
||||
# Hardcoded demo setups; might undergo some changes in the future
|
||||
|
||||
batch = {}
|
||||
@@ -605,21 +653,15 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch["txt"] = [value_dict["prompt"]] * math.prod(N)
|
||||
|
||||
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
|
||||
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
.repeat(math.prod(N), 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
@@ -627,30 +669,67 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
.repeat(math.prod(N), 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(math.prod(N), 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
.repeat(math.prod(N), 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
.repeat(math.prod(N), 1)
|
||||
)
|
||||
elif key == "fps":
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
|
||||
)
|
||||
elif key == "fps_id":
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
|
||||
)
|
||||
elif key == "motion_bucket_id":
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["motion_bucket_id"]])
|
||||
.to(device)
|
||||
.repeat(math.prod(N))
|
||||
)
|
||||
elif key == "pool_image":
|
||||
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
|
||||
device, dtype=torch.half
|
||||
)
|
||||
elif key == "cond_aug":
|
||||
batch[key] = repeat(
|
||||
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
|
||||
"1 -> b",
|
||||
b=math.prod(N),
|
||||
)
|
||||
elif key == "cond_frames":
|
||||
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
||||
elif key == "cond_frames_without_noise":
|
||||
batch[key] = repeat(
|
||||
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
||||
)
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
if T is not None:
|
||||
batch["num_video_frames"] = T
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
elif key in additional_batch_uc_fields and key not in batch_uc:
|
||||
batch_uc[key] = copy.copy(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
@@ -661,7 +740,8 @@ def do_img2img(
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=[],
|
||||
force_uc_zero_embeddings: Optional[List] = None,
|
||||
force_cond_zero_embeddings: Optional[List] = None,
|
||||
additional_kwargs={},
|
||||
offset_noise_level: int = 0.0,
|
||||
return_latents=False,
|
||||
@@ -686,6 +766,7 @@ def do_img2img(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
||||
)
|
||||
unload_model(model.conditioner)
|
||||
for k in c:
|
||||
@@ -736,9 +817,112 @@ def do_img2img(
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
grid = embed_watemark(torch.stack([samples]))
|
||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||
outputs.image(grid.cpu().numpy())
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
|
||||
|
||||
def get_resizing_factor(
|
||||
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
|
||||
) -> float:
|
||||
r_bound = desired_shape[1] / desired_shape[0]
|
||||
aspect_r = current_shape[1] / current_shape[0]
|
||||
if r_bound >= 1.0:
|
||||
if aspect_r >= r_bound:
|
||||
factor = min(desired_shape) / min(current_shape)
|
||||
else:
|
||||
if aspect_r < 1.0:
|
||||
factor = max(desired_shape) / min(current_shape)
|
||||
else:
|
||||
factor = max(desired_shape) / max(current_shape)
|
||||
else:
|
||||
if aspect_r <= r_bound:
|
||||
factor = min(desired_shape) / min(current_shape)
|
||||
else:
|
||||
if aspect_r > 1:
|
||||
factor = max(desired_shape) / min(current_shape)
|
||||
else:
|
||||
factor = max(desired_shape) / max(current_shape)
|
||||
|
||||
return factor
|
||||
|
||||
|
||||
def get_interactive_image(key=None) -> Image.Image:
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def load_img_for_prediction(
|
||||
W: int, H: int, display=True, key=None, device="cuda"
|
||||
) -> torch.Tensor:
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
if display:
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
|
||||
image = np.array(image).transpose(2, 0, 1)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
rfs = get_resizing_factor((H, W), (h, w))
|
||||
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
|
||||
top = (resize_size[0] - H) // 2
|
||||
left = (resize_size[1] - W) // 2
|
||||
|
||||
image = torch.nn.functional.interpolate(
|
||||
image, resize_size, mode="area", antialias=False
|
||||
)
|
||||
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
|
||||
|
||||
if display:
|
||||
numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
|
||||
pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
|
||||
st.image(pil_image)
|
||||
return image.to(device) * 2.0 - 1.0
|
||||
|
||||
|
||||
def save_video_as_grid_and_mp4(
|
||||
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
|
||||
):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
base_count = len(glob(os.path.join(save_path, "*.mp4")))
|
||||
|
||||
video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
|
||||
video_batch = embed_watermark(video_batch)
|
||||
for vid in video_batch:
|
||||
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
|
||||
|
||||
video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
|
||||
|
||||
writer = cv2.VideoWriter(
|
||||
video_path,
|
||||
cv2.VideoWriter_fourcc(*"MP4V"),
|
||||
fps,
|
||||
(vid.shape[-1], vid.shape[-2]),
|
||||
)
|
||||
|
||||
vid = (
|
||||
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
|
||||
)
|
||||
for frame in vid:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
writer.write(frame)
|
||||
|
||||
writer.release()
|
||||
|
||||
video_path_h264 = video_path[:-4] + "_h264.mp4"
|
||||
os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
|
||||
|
||||
with open(video_path_h264, "rb") as f:
|
||||
video_bytes = f.read()
|
||||
st.video(video_bytes)
|
||||
|
||||
base_count += 1
|
||||
|
||||
200
scripts/demo/video_sampling.py
Normal file
200
scripts/demo/video_sampling.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import os
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from scripts.demo.streamlit_helpers import *
|
||||
|
||||
SAVE_PATH = "outputs/demo/vid/"
|
||||
|
||||
VERSION2SPECS = {
|
||||
"svd": {
|
||||
"T": 14,
|
||||
"H": 576,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/svd.yaml",
|
||||
"ckpt": "checkpoints/svd.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 25,
|
||||
},
|
||||
},
|
||||
"svd_image_decoder": {
|
||||
"T": 14,
|
||||
"H": 576,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/svd_image_decoder.yaml",
|
||||
"ckpt": "checkpoints/svd_image_decoder.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 25,
|
||||
},
|
||||
},
|
||||
"svd_xt": {
|
||||
"T": 25,
|
||||
"H": 576,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/svd.yaml",
|
||||
"ckpt": "checkpoints/svd_xt.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 3.0,
|
||||
"min_cfg": 1.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 30,
|
||||
"decoding_t": 14,
|
||||
},
|
||||
},
|
||||
"svd_xt_image_decoder": {
|
||||
"T": 25,
|
||||
"H": 576,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/svd_image_decoder.yaml",
|
||||
"ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 3.0,
|
||||
"min_cfg": 1.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 2,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 30,
|
||||
"decoding_t": 14,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Video Diffusion")
|
||||
version = st.selectbox(
|
||||
"Model Version",
|
||||
[k for k in VERSION2SPECS.keys()],
|
||||
0,
|
||||
)
|
||||
version_dict = VERSION2SPECS[version]
|
||||
if st.checkbox("Load Model"):
|
||||
mode = "img2vid"
|
||||
else:
|
||||
mode = "skip"
|
||||
|
||||
H = st.sidebar.number_input(
|
||||
"H", value=version_dict["H"], min_value=64, max_value=2048
|
||||
)
|
||||
W = st.sidebar.number_input(
|
||||
"W", value=version_dict["W"], min_value=64, max_value=2048
|
||||
)
|
||||
T = st.sidebar.number_input(
|
||||
"T", value=version_dict["T"], min_value=0, max_value=128
|
||||
)
|
||||
C = version_dict["C"]
|
||||
F = version_dict["f"]
|
||||
options = version_dict["options"]
|
||||
|
||||
if mode != "skip":
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
|
||||
ukeys = set(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
|
||||
)
|
||||
|
||||
value_dict = init_embedder_options(
|
||||
ukeys,
|
||||
{},
|
||||
)
|
||||
|
||||
value_dict["image_only_indicator"] = 0
|
||||
|
||||
if mode == "img2vid":
|
||||
img = load_img_for_prediction(W, H)
|
||||
cond_aug = st.number_input(
|
||||
"Conditioning augmentation:", value=0.02, min_value=0.0
|
||||
)
|
||||
value_dict["cond_frames_without_noise"] = img
|
||||
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
|
||||
value_dict["cond_aug"] = cond_aug
|
||||
|
||||
seed = st.sidebar.number_input(
|
||||
"seed", value=23, min_value=0, max_value=int(1e9)
|
||||
)
|
||||
seed_everything(seed)
|
||||
|
||||
save_locally, save_path = init_save_locally(
|
||||
os.path.join(SAVE_PATH, version), init_value=True
|
||||
)
|
||||
|
||||
options["num_frames"] = T
|
||||
|
||||
sampler, num_rows, num_cols = init_sampling(options=options)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
decoding_t = st.number_input(
|
||||
"Decode t frames at a time (set small if you are low on VRAM)",
|
||||
value=options.get("decoding_t", T),
|
||||
min_value=1,
|
||||
max_value=int(1e9),
|
||||
)
|
||||
|
||||
if st.checkbox("Overwrite fps in mp4 generator", False):
|
||||
saving_fps = st.number_input(
|
||||
f"saving video at fps:", value=value_dict["fps"], min_value=1
|
||||
)
|
||||
else:
|
||||
saving_fps = value_dict["fps"]
|
||||
|
||||
if st.button("Sample"):
|
||||
out = do_sample(
|
||||
model,
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
T=T,
|
||||
batch2model_input=["num_video_frames", "image_only_indicator"],
|
||||
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
|
||||
force_cond_zero_embeddings=options.get(
|
||||
"force_cond_zero_embeddings", None
|
||||
),
|
||||
return_latents=False,
|
||||
decoding_t=decoding_t,
|
||||
)
|
||||
|
||||
if isinstance(out, (tuple, list)):
|
||||
samples, samples_z = out
|
||||
else:
|
||||
samples = out
|
||||
samples_z = None
|
||||
|
||||
if save_locally:
|
||||
save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
|
||||
146
scripts/sampling/configs/svd.yaml
Normal file
146
scripts/sampling/configs/svd.yaml
Normal file
@@ -0,0 +1,146 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/svd.safetensors
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Encoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
decoder_config:
|
||||
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 700.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||
params:
|
||||
max_scale: 2.5
|
||||
min_scale: 1.0
|
||||
129
scripts/sampling/configs/svd_image_decoder.yaml
Normal file
129
scripts/sampling/configs/svd_image_decoder.yaml
Normal file
@@ -0,0 +1,129 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/svd_image_decoder.safetensors
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 700.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||
params:
|
||||
max_scale: 2.5
|
||||
min_scale: 1.0
|
||||
146
scripts/sampling/configs/svd_xt.yaml
Normal file
146
scripts/sampling/configs/svd_xt.yaml
Normal file
@@ -0,0 +1,146 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/svd_xt.safetensors
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencodingEngine
|
||||
params:
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
regularizer_config:
|
||||
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
|
||||
encoder_config:
|
||||
target: sgm.modules.diffusionmodules.model.Encoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
decoder_config:
|
||||
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
|
||||
params:
|
||||
attn_type: vanilla
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 700.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||
params:
|
||||
max_scale: 3.0
|
||||
min_scale: 1.5
|
||||
129
scripts/sampling/configs/svd_xt_image_decoder.yaml
Normal file
129
scripts/sampling/configs/svd_xt_image_decoder.yaml
Normal file
@@ -0,0 +1,129 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
disable_first_stage_autocast: True
|
||||
ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.Denoiser
|
||||
params:
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.video_model.VideoUNet
|
||||
params:
|
||||
adm_in_channels: 768
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2, 1]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4, 4]
|
||||
num_head_channels: 64
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
extra_ff_mix_layer: True
|
||||
use_spatial_context: True
|
||||
merge_strategy: learned_with_images
|
||||
video_kernel_size: [3, 1, 1]
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: False
|
||||
input_key: cond_frames_without_noise
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
|
||||
params:
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
open_clip_embedding_config:
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
|
||||
- input_key: fps_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: motion_bucket_id
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
- input_key: cond_frames
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
|
||||
params:
|
||||
disable_encoder_autocast: True
|
||||
n_cond_frames: 1
|
||||
n_copies: 1
|
||||
is_ae: True
|
||||
encoder_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLModeOnly
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
- input_key: cond_aug
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
|
||||
params:
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
|
||||
params:
|
||||
sigma_max: 700.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
|
||||
params:
|
||||
max_scale: 3.0
|
||||
min_scale: 1.5
|
||||
278
scripts/sampling/simple_video_sample.py
Normal file
278
scripts/sampling/simple_video_sample.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import math
|
||||
import os
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from fire import Fire
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
||||
DeepFloydDataFiltering
|
||||
from sgm.inference.helpers import embed_watermark
|
||||
from sgm.util import default, instantiate_from_config
|
||||
|
||||
|
||||
def sample(
|
||||
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
|
||||
num_frames: Optional[int] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
version: str = "svd",
|
||||
fps_id: int = 6,
|
||||
motion_bucket_id: int = 127,
|
||||
cond_aug: float = 0.02,
|
||||
seed: int = 23,
|
||||
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||
device: str = "cuda",
|
||||
output_folder: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
|
||||
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
||||
"""
|
||||
|
||||
if version == "svd":
|
||||
num_frames = default(num_frames, 14)
|
||||
num_steps = default(num_steps, 25)
|
||||
output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
|
||||
model_config = "scripts/sampling/configs/svd.yaml"
|
||||
elif version == "svd_xt":
|
||||
num_frames = default(num_frames, 25)
|
||||
num_steps = default(num_steps, 30)
|
||||
output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
|
||||
model_config = "scripts/sampling/configs/svd_xt.yaml"
|
||||
elif version == "svd_image_decoder":
|
||||
num_frames = default(num_frames, 14)
|
||||
num_steps = default(num_steps, 25)
|
||||
output_folder = default(
|
||||
output_folder, "outputs/simple_video_sample/svd_image_decoder/"
|
||||
)
|
||||
model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
|
||||
elif version == "svd_xt_image_decoder":
|
||||
num_frames = default(num_frames, 25)
|
||||
num_steps = default(num_steps, 30)
|
||||
output_folder = default(
|
||||
output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
|
||||
)
|
||||
model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
|
||||
else:
|
||||
raise ValueError(f"Version {version} does not exist.")
|
||||
|
||||
model, filter = load_model(
|
||||
model_config,
|
||||
device,
|
||||
num_frames,
|
||||
num_steps,
|
||||
)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
path = Path(input_path)
|
||||
all_img_paths = []
|
||||
if path.is_file():
|
||||
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
|
||||
all_img_paths = [input_path]
|
||||
else:
|
||||
raise ValueError("Path is not valid image file.")
|
||||
elif path.is_dir():
|
||||
all_img_paths = sorted(
|
||||
[
|
||||
f
|
||||
for f in path.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
|
||||
]
|
||||
)
|
||||
if len(all_img_paths) == 0:
|
||||
raise ValueError("Folder does not contain any images.")
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
for input_img_path in all_img_paths:
|
||||
with Image.open(input_img_path) as image:
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
w, h = image.size
|
||||
|
||||
if h % 64 != 0 or w % 64 != 0:
|
||||
width, height = map(lambda x: x - x % 64, (w, h))
|
||||
image = image.resize((width, height))
|
||||
print(
|
||||
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
|
||||
)
|
||||
|
||||
image = ToTensor()(image)
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
image = image.unsqueeze(0).to(device)
|
||||
H, W = image.shape[2:]
|
||||
assert image.shape[1] == 3
|
||||
F = 8
|
||||
C = 4
|
||||
shape = (num_frames, C, H // F, W // F)
|
||||
if (H, W) != (576, 1024):
|
||||
print(
|
||||
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
|
||||
)
|
||||
if motion_bucket_id > 255:
|
||||
print(
|
||||
"WARNING: High motion bucket! This may lead to suboptimal performance."
|
||||
)
|
||||
|
||||
if fps_id < 5:
|
||||
print("WARNING: Small fps value! This may lead to suboptimal performance.")
|
||||
|
||||
if fps_id > 30:
|
||||
print("WARNING: Large fps value! This may lead to suboptimal performance.")
|
||||
|
||||
value_dict = {}
|
||||
value_dict["motion_bucket_id"] = motion_bucket_id
|
||||
value_dict["fps_id"] = fps_id
|
||||
value_dict["cond_aug"] = cond_aug
|
||||
value_dict["cond_frames_without_noise"] = image
|
||||
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
|
||||
value_dict["cond_aug"] = cond_aug
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.autocast(device):
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
[1, num_frames],
|
||||
T=num_frames,
|
||||
device=device,
|
||||
)
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=[
|
||||
"cond_frames",
|
||||
"cond_frames_without_noise",
|
||||
],
|
||||
)
|
||||
|
||||
for k in ["crossattn", "concat"]:
|
||||
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
||||
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
||||
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
||||
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
||||
|
||||
randn = torch.randn(shape, device=device)
|
||||
|
||||
additional_model_inputs = {}
|
||||
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
||||
2, num_frames
|
||||
).to(device)
|
||||
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
||||
|
||||
def denoiser(input, sigma, c):
|
||||
return model.denoiser(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
|
||||
model.en_and_decode_n_samples_a_time = decoding_t
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
||||
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
||||
writer = cv2.VideoWriter(
|
||||
video_path,
|
||||
cv2.VideoWriter_fourcc(*"MP4V"),
|
||||
fps_id + 1,
|
||||
(samples.shape[-1], samples.shape[-2]),
|
||||
)
|
||||
|
||||
samples = embed_watermark(samples)
|
||||
samples = filter(samples)
|
||||
vid = (
|
||||
(rearrange(samples, "t c h w -> t h w c") * 255)
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(np.uint8)
|
||||
)
|
||||
for frame in vid:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
writer.write(frame)
|
||||
writer.release()
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N, T, device):
|
||||
batch = {}
|
||||
batch_uc = {}
|
||||
|
||||
for key in keys:
|
||||
if key == "fps_id":
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["fps_id"]])
|
||||
.to(device)
|
||||
.repeat(int(math.prod(N)))
|
||||
)
|
||||
elif key == "motion_bucket_id":
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["motion_bucket_id"]])
|
||||
.to(device)
|
||||
.repeat(int(math.prod(N)))
|
||||
)
|
||||
elif key == "cond_aug":
|
||||
batch[key] = repeat(
|
||||
torch.tensor([value_dict["cond_aug"]]).to(device),
|
||||
"1 -> b",
|
||||
b=math.prod(N),
|
||||
)
|
||||
elif key == "cond_frames":
|
||||
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
||||
elif key == "cond_frames_without_noise":
|
||||
batch[key] = repeat(
|
||||
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
|
||||
)
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
if T is not None:
|
||||
batch["num_video_frames"] = T
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
def load_model(
|
||||
config: str,
|
||||
device: str,
|
||||
num_frames: int,
|
||||
num_steps: int,
|
||||
):
|
||||
config = OmegaConf.load(config)
|
||||
if device == "cuda":
|
||||
config.model.params.conditioner_config.params.emb_models[
|
||||
0
|
||||
].params.open_clip_embedding_config.params.init_device = device
|
||||
|
||||
config.model.params.sampler_config.params.num_steps = num_steps
|
||||
config.model.params.sampler_config.params.guider_config.params.num_frames = (
|
||||
num_frames
|
||||
)
|
||||
if device == "cuda":
|
||||
with torch.device(device):
|
||||
model = instantiate_from_config(config.model).to(device).eval()
|
||||
else:
|
||||
model = instantiate_from_config(config.model).to(device).eval()
|
||||
|
||||
filter = DeepFloydDataFiltering(verbose=False, device=device)
|
||||
return model, filter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Fire(sample)
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import einops
|
||||
from torch.backends.cuda import SDPBackend
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as benchmark
|
||||
from torch.backends.cuda import SDPBackend
|
||||
|
||||
from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
|
||||
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
|
||||
|
||||
|
||||
def benchmark_attn():
|
||||
|
||||
@@ -37,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
class DeepFloydDataFiltering(object):
|
||||
def __init__(self, verbose: bool = False):
|
||||
def __init__(
|
||||
self, verbose: bool = False, device: torch.device = torch.device("cpu")
|
||||
):
|
||||
super().__init__()
|
||||
self.verbose = verbose
|
||||
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
|
||||
self._device = None
|
||||
self.clip_model, _ = clip.load("ViT-L/14", device=device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
||||
@@ -54,7 +57,9 @@ class DeepFloydDataFiltering(object):
|
||||
@torch.inference_mode()
|
||||
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
||||
imgs = clip_process_images(images)
|
||||
image_features = self.clip_model.encode_image(imgs.to("cpu"))
|
||||
if self._device is None:
|
||||
self._device = next(p for p in self.clip_model.parameters()).device
|
||||
image_features = self.clip_model.encode_image(imgs.to(self._device))
|
||||
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
||||
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
||||
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from omegaconf import OmegaConf
|
||||
import pathlib
|
||||
from sgm.inference.helpers import (
|
||||
do_sample,
|
||||
do_img2img,
|
||||
Img2ImgDiscretizationWrapper,
|
||||
)
|
||||
from sgm.modules.diffusionmodules.sampling import (
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
EulerAncestralSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
DPMPP2MSampler,
|
||||
LinearMultistepSampler,
|
||||
)
|
||||
from sgm.util import load_model_from_config
|
||||
from dataclasses import asdict, dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
||||
do_sample)
|
||||
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
EulerAncestralSampler,
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
LinearMultistepSampler)
|
||||
from sgm.util import load_model_from_config
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
SD_2_1 = "stable-diffusion-v2-1"
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import os
|
||||
from typing import Union, List, Optional
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from imwatermark import WatermarkEncoder
|
||||
from omegaconf import ListConfig
|
||||
from PIL import Image
|
||||
from torch import autocast
|
||||
|
||||
from sgm.util import append_dims
|
||||
@@ -20,17 +20,16 @@ class WatermarkEmbedder:
|
||||
self.encoder = WatermarkEncoder()
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def __call__(self, image: torch.Tensor):
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Adds a predefined watermark to the input image
|
||||
|
||||
Args:
|
||||
image: ([N,] B, C, H, W) in range [0, 1]
|
||||
image: ([N,] B, RGB, H, W) in range [0, 1]
|
||||
|
||||
Returns:
|
||||
same as input but watermarked
|
||||
"""
|
||||
# watermarking libary expects input as cv2 BGR format
|
||||
squeeze = len(image.shape) == 4
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
@@ -39,6 +38,7 @@ class WatermarkEmbedder:
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
# watermarking libary expects input as cv2 BGR format
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
image = torch.from_numpy(
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from omegaconf import ListConfig
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from ..modules.diffusionmodules.model import Decoder, Encoder
|
||||
from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
||||
from ..modules.ema import LitEma
|
||||
from ..util import default, get_obj_from_str, instantiate_from_config
|
||||
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
||||
instantiate_from_config)
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractAutoencoder(pl.LightningModule):
|
||||
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
ema_decay: Union[None, float] = None,
|
||||
monitor: Union[None, str] = None,
|
||||
input_key: str = "jpg",
|
||||
ckpt_path: Union[None, str] = None,
|
||||
ignore_keys: Union[Tuple, list, ListConfig] = (),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_key = input_key
|
||||
self.use_ema = ema_decay is not None
|
||||
if monitor is not None:
|
||||
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
self.automatic_optimization = False
|
||||
|
||||
def init_from_ckpt(
|
||||
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
|
||||
) -> None:
|
||||
if path.endswith("ckpt"):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
elif path.endswith("safetensors"):
|
||||
sd = load_safetensors(path)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if re.match(ik, k):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
if isinstance(ckpt, str):
|
||||
ckpt = {
|
||||
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
||||
"params": {"ckpt_path": ckpt},
|
||||
}
|
||||
engine = instantiate_from_config(ckpt)
|
||||
engine(self)
|
||||
|
||||
@abstractmethod
|
||||
def get_input(self, batch) -> Any:
|
||||
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
logpy.info(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
logpy.info(f"{context}: Restored training weights")
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
print(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
regularizer_config: Dict,
|
||||
optimizer_config: Union[Dict, None] = None,
|
||||
lr_g_factor: float = 1.0,
|
||||
trainable_ae_params: Optional[List[List[str]]] = None,
|
||||
ae_optimizer_args: Optional[List[dict]] = None,
|
||||
trainable_disc_params: Optional[List[List[str]]] = None,
|
||||
disc_optimizer_args: Optional[List[dict]] = None,
|
||||
disc_start_iter: int = 0,
|
||||
diff_boost_factor: float = 3.0,
|
||||
ckpt_engine: Union[None, str, dict] = None,
|
||||
ckpt_path: Optional[str] = None,
|
||||
additional_decode_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
# todo: add options to freeze encoder/decoder
|
||||
self.encoder = instantiate_from_config(encoder_config)
|
||||
self.decoder = instantiate_from_config(decoder_config)
|
||||
self.loss = instantiate_from_config(loss_config)
|
||||
self.regularization = instantiate_from_config(regularizer_config)
|
||||
self.automatic_optimization = False # pytorch lightning
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
self.optimizer_config = default(
|
||||
optimizer_config, {"target": "torch.optim.Adam"}
|
||||
)
|
||||
self.diff_boost_factor = diff_boost_factor
|
||||
self.disc_start_iter = disc_start_iter
|
||||
self.lr_g_factor = lr_g_factor
|
||||
self.trainable_ae_params = trainable_ae_params
|
||||
if self.trainable_ae_params is not None:
|
||||
self.ae_optimizer_args = default(
|
||||
ae_optimizer_args,
|
||||
[{} for _ in range(len(self.trainable_ae_params))],
|
||||
)
|
||||
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
||||
else:
|
||||
self.ae_optimizer_args = [{}] # makes type consitent
|
||||
|
||||
self.trainable_disc_params = trainable_disc_params
|
||||
if self.trainable_disc_params is not None:
|
||||
self.disc_optimizer_args = default(
|
||||
disc_optimizer_args,
|
||||
[{} for _ in range(len(self.trainable_disc_params))],
|
||||
)
|
||||
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
||||
else:
|
||||
self.disc_optimizer_args = [{}] # makes type consitent
|
||||
|
||||
if ckpt_path is not None:
|
||||
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
||||
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
||||
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
||||
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
||||
|
||||
def get_input(self, batch: Dict) -> torch.Tensor:
|
||||
# assuming unified data format, dataloader returns a dict.
|
||||
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
|
||||
# image tensors should be scaled to -1 ... 1 and in channels-first
|
||||
# format (e.g., bchw instead if bhwc)
|
||||
return batch[self.input_key]
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = (
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.regularization.get_trainable_parameters())
|
||||
+ list(self.loss.get_trainable_autoencoder_parameters())
|
||||
)
|
||||
params = []
|
||||
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
||||
params += list(self.loss.get_trainable_autoencoder_parameters())
|
||||
if hasattr(self.regularization, "get_trainable_parameters"):
|
||||
params += list(self.regularization.get_trainable_parameters())
|
||||
params = params + list(self.encoder.parameters())
|
||||
params = params + list(self.decoder.parameters())
|
||||
return params
|
||||
|
||||
def get_discriminator_params(self) -> list:
|
||||
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
||||
if hasattr(self.loss, "get_trainable_parameters"):
|
||||
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
||||
else:
|
||||
params = []
|
||||
return params
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(self, x: Any, return_reg_log: bool = False) -> Any:
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x)
|
||||
if unregularized:
|
||||
return z, dict()
|
||||
z, reg_log = self.regularization(z)
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: Any) -> torch.Tensor:
|
||||
x = self.decoder(z)
|
||||
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
|
||||
def inner_training_step(
|
||||
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
||||
) -> torch.Tensor:
|
||||
x = self.get_input(batch)
|
||||
z, xrec, regularization_log = self(x)
|
||||
additional_decode_kwargs = {
|
||||
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
||||
}
|
||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
"z": z,
|
||||
"optimizer_idx": optimizer_idx,
|
||||
"global_step": self.global_step,
|
||||
"last_layer": self.get_last_layer(),
|
||||
"split": "train",
|
||||
"regularization_log": regularization_log,
|
||||
"autoencoder": self,
|
||||
}
|
||||
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
||||
else:
|
||||
extra_info = dict()
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
regularization_log,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
)
|
||||
out_loss = self.loss(x, xrec, **extra_info)
|
||||
if isinstance(out_loss, tuple):
|
||||
aeloss, log_dict_ae = out_loss
|
||||
else:
|
||||
# simple loss function
|
||||
aeloss = out_loss
|
||||
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
||||
|
||||
self.log_dict(
|
||||
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
sync_dist=False,
|
||||
)
|
||||
self.log(
|
||||
"loss",
|
||||
aeloss.mean().detach(),
|
||||
prog_bar=True,
|
||||
logger=False,
|
||||
on_epoch=False,
|
||||
on_step=True,
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
regularization_log,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
)
|
||||
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
||||
# -> discriminator always needs to return a tuple
|
||||
self.log_dict(
|
||||
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
||||
)
|
||||
return discloss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
def training_step(self, batch: dict, batch_idx: int):
|
||||
opts = self.optimizers()
|
||||
if not isinstance(opts, list):
|
||||
# Non-adversarial case
|
||||
opts = [opts]
|
||||
optimizer_idx = batch_idx % len(opts)
|
||||
if self.global_step < self.disc_start_iter:
|
||||
optimizer_idx = 0
|
||||
opt = opts[optimizer_idx]
|
||||
opt.zero_grad()
|
||||
with opt.toggle_model():
|
||||
loss = self.inner_training_step(
|
||||
batch, batch_idx, optimizer_idx=optimizer_idx
|
||||
)
|
||||
self.manual_backward(loss)
|
||||
opt.step()
|
||||
|
||||
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
||||
log_dict.update(log_dict_ema)
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
|
||||
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
||||
x = self.get_input(batch)
|
||||
|
||||
z, xrec, regularization_log = self(x)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
regularization_log,
|
||||
x,
|
||||
xrec,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val" + postfix,
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
"z": z,
|
||||
"optimizer_idx": 0,
|
||||
"global_step": self.global_step,
|
||||
"last_layer": self.get_last_layer(),
|
||||
"split": "val" + postfix,
|
||||
"regularization_log": regularization_log,
|
||||
"autoencoder": self,
|
||||
}
|
||||
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
||||
else:
|
||||
extra_info = dict()
|
||||
out_loss = self.loss(x, xrec, **extra_info)
|
||||
if isinstance(out_loss, tuple):
|
||||
aeloss, log_dict_ae = out_loss
|
||||
else:
|
||||
# simple loss function
|
||||
aeloss = out_loss
|
||||
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
||||
full_log_dict = log_dict_ae
|
||||
|
||||
if "optimizer_idx" in extra_info:
|
||||
extra_info["optimizer_idx"] = 1
|
||||
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
||||
full_log_dict.update(log_dict_disc)
|
||||
self.log(
|
||||
f"val{postfix}/loss/rec",
|
||||
log_dict_ae[f"val{postfix}/loss/rec"],
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log_dict(full_log_dict, sync_dist=True)
|
||||
return full_log_dict
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
regularization_log,
|
||||
x,
|
||||
xrec,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val" + postfix,
|
||||
)
|
||||
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
||||
log_dict_ae.update(log_dict_disc)
|
||||
self.log_dict(log_dict_ae)
|
||||
return log_dict_ae
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
disc_params = self.get_discriminator_params()
|
||||
def get_param_groups(
|
||||
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
groups = []
|
||||
num_params = 0
|
||||
for names, args in zip(parameter_names, optimizer_args):
|
||||
params = []
|
||||
for pattern_ in names:
|
||||
pattern_params = []
|
||||
pattern = re.compile(pattern_)
|
||||
for p_name, param in self.named_parameters():
|
||||
if re.match(pattern, p_name):
|
||||
pattern_params.append(param)
|
||||
num_params += param.numel()
|
||||
if len(pattern_params) == 0:
|
||||
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
||||
params.extend(pattern_params)
|
||||
groups.append({"params": params, **args})
|
||||
return groups, num_params
|
||||
|
||||
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
||||
if self.trainable_ae_params is None:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
else:
|
||||
ae_params, num_ae_params = self.get_param_groups(
|
||||
self.trainable_ae_params, self.ae_optimizer_args
|
||||
)
|
||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||
if self.trainable_disc_params is None:
|
||||
disc_params = self.get_discriminator_params()
|
||||
else:
|
||||
disc_params, num_disc_params = self.get_param_groups(
|
||||
self.trainable_disc_params, self.disc_optimizer_args
|
||||
)
|
||||
logpy.info(
|
||||
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
||||
)
|
||||
opt_ae = self.instantiate_optimizer_from_config(
|
||||
ae_params,
|
||||
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
||||
self.optimizer_config,
|
||||
)
|
||||
opt_disc = self.instantiate_optimizer_from_config(
|
||||
disc_params, self.learning_rate, self.optimizer_config
|
||||
)
|
||||
opts = [opt_ae]
|
||||
if len(disc_params) > 0:
|
||||
opt_disc = self.instantiate_optimizer_from_config(
|
||||
disc_params, self.learning_rate, self.optimizer_config
|
||||
)
|
||||
opts.append(opt_disc)
|
||||
|
||||
return [opt_ae, opt_disc], []
|
||||
return opts
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch: Dict, **kwargs) -> Dict:
|
||||
def log_images(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
log = dict()
|
||||
additional_decode_kwargs = {}
|
||||
x = self.get_input(batch)
|
||||
_, xrec, _ = self(x)
|
||||
additional_decode_kwargs.update(
|
||||
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
)
|
||||
|
||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
||||
diff.clamp_(0, 1.0)
|
||||
log["diff"] = 2.0 * diff - 1.0
|
||||
# diff_boost shows location of small errors, by boosting their
|
||||
# brightness.
|
||||
log["diff_boost"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
||||
)
|
||||
if hasattr(self.loss, "log_images"):
|
||||
log.update(self.loss.log_images(x, xrec))
|
||||
with self.ema_scope():
|
||||
_, xrec_ema, _ = self(x)
|
||||
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||
diff_ema.clamp_(0, 1.0)
|
||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||
log["diff_boost_ema"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
)
|
||||
if additional_log_kwargs:
|
||||
additional_decode_kwargs.update(additional_log_kwargs)
|
||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||
log_str = "reconstructions-" + "-".join(
|
||||
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
||||
)
|
||||
log[log_str] = xrec_add
|
||||
return log
|
||||
|
||||
|
||||
class AutoencoderKL(AutoencodingEngine):
|
||||
class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
def __init__(self, embed_dim: int, **kwargs):
|
||||
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
||||
ddconfig = kwargs.pop("ddconfig")
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ignore_keys = kwargs.pop("ignore_keys", ())
|
||||
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
||||
super().__init__(
|
||||
encoder_config={"target": "torch.nn.Identity"},
|
||||
decoder_config={"target": "torch.nn.Identity"},
|
||||
regularizer_config={"target": "torch.nn.Identity"},
|
||||
loss_config=kwargs.pop("lossconfig"),
|
||||
encoder_config={
|
||||
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
||||
"params": ddconfig,
|
||||
},
|
||||
decoder_config={
|
||||
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
||||
"params": ddconfig,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
assert ddconfig["double_z"]
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
||||
self.quant_conv = torch.nn.Conv2d(
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
||||
|
||||
def encode(self, x):
|
||||
assert (
|
||||
not self.training
|
||||
), f"{self.__class__.__name__} only supports inference currently"
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
else:
|
||||
N = x.shape[0]
|
||||
bs = self.max_batch_size
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
z = list()
|
||||
for i_batch in range(n_batches):
|
||||
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
||||
z_batch = self.quant_conv(z_batch)
|
||||
z.append(z_batch)
|
||||
z = torch.cat(z, 0)
|
||||
|
||||
z, reg_log = self.regularization(z)
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||
if self.max_batch_size is None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec = self.decoder(dec, **decoder_kwargs)
|
||||
else:
|
||||
N = z.shape[0]
|
||||
bs = self.max_batch_size
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
dec = list()
|
||||
for i_batch in range(n_batches):
|
||||
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
||||
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
||||
dec.append(dec_batch)
|
||||
dec = torch.cat(dec, 0)
|
||||
|
||||
def decode(self, z, **decoder_kwargs):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, **decoder_kwargs)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKLInferenceWrapper(AutoencoderKL):
|
||||
def encode(self, x):
|
||||
return super().encode(x).sample()
|
||||
class AutoencoderKL(AutoencodingEngineLegacy):
|
||||
def __init__(self, **kwargs):
|
||||
if "lossconfig" in kwargs:
|
||||
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||
super().__init__(
|
||||
regularizer_config={
|
||||
"target": (
|
||||
"sgm.modules.autoencoding.regularizers"
|
||||
".DiagonalGaussianRegularizer"
|
||||
)
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
n_embed: int,
|
||||
sane_index_shape: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if "lossconfig" in kwargs:
|
||||
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
||||
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||
super().__init__(
|
||||
regularizer_config={
|
||||
"target": (
|
||||
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
||||
),
|
||||
"params": {
|
||||
"n_e": n_embed,
|
||||
"e_dim": embed_dim,
|
||||
"sane_index_shape": sane_index_shape,
|
||||
},
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class IdentityFirstStage(AbstractAutoencoder):
|
||||
@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
|
||||
|
||||
def decode(self, x: Any, *args, **kwargs) -> Any:
|
||||
return x
|
||||
|
||||
|
||||
class AEIntegerWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
||||
regularization_key: str = "regularization",
|
||||
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
assert hasattr(model, "encode") and hasattr(
|
||||
model, "decode"
|
||||
), "Need AE interface"
|
||||
self.regularization = get_nested_attribute(model, regularization_key)
|
||||
self.shape = shape
|
||||
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
||||
|
||||
def encode(self, x) -> torch.Tensor:
|
||||
assert (
|
||||
not self.training
|
||||
), f"{self.__class__.__name__} only supports inference currently"
|
||||
_, log = self.model.encode(x, **self.encoder_kwargs)
|
||||
assert isinstance(log, dict)
|
||||
inds = log["min_encoding_indices"]
|
||||
return rearrange(inds, "b ... -> b (...)")
|
||||
|
||||
def decode(
|
||||
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
||||
) -> torch.Tensor:
|
||||
# expect inds shape (b, s) with s = h*w
|
||||
shape = default(shape, self.shape) # Optional[(h, w)]
|
||||
if shape is not None:
|
||||
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
||||
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
||||
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
||||
h = rearrange(h, "b h w c -> b c h w")
|
||||
return self.model.decode(h)
|
||||
|
||||
|
||||
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
||||
def __init__(self, **kwargs):
|
||||
if "lossconfig" in kwargs:
|
||||
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||
super().__init__(
|
||||
regularizer_config={
|
||||
"target": (
|
||||
"sgm.modules.autoencoding.regularizers"
|
||||
".DiagonalGaussianRegularizer"
|
||||
),
|
||||
"params": {"sample": False},
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from ..modules import UNCONDITIONAL_CONFIG
|
||||
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
||||
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
||||
from ..modules.ema import LitEma
|
||||
from ..util import (
|
||||
default,
|
||||
disabled_train,
|
||||
get_obj_from_str,
|
||||
instantiate_from_config,
|
||||
log_txt_as_img,
|
||||
)
|
||||
from ..util import (default, disabled_train, get_obj_from_str,
|
||||
instantiate_from_config, log_txt_as_img)
|
||||
|
||||
|
||||
class DiffusionEngine(pl.LightningModule):
|
||||
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
log_keys: Union[List, None] = None,
|
||||
no_cond_log: bool = False,
|
||||
compile_model: bool = False,
|
||||
en_and_decode_n_samples_a_time: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.log_keys = log_keys
|
||||
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
||||
|
||||
def init_from_ckpt(
|
||||
self,
|
||||
path: str,
|
||||
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
|
||||
@torch.no_grad()
|
||||
def decode_first_stage(self, z):
|
||||
z = 1.0 / self.scale_factor * z
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
||||
|
||||
n_rounds = math.ceil(z.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
out = self.first_stage_model.decode(z)
|
||||
for n in range(n_rounds):
|
||||
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
||||
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
||||
else:
|
||||
kwargs = {}
|
||||
out = self.first_stage_model.decode(
|
||||
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
||||
)
|
||||
all_out.append(out)
|
||||
out = torch.cat(all_out, dim=0)
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_first_stage(self, x):
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
||||
n_rounds = math.ceil(x.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
z = self.first_stage_model.encode(x)
|
||||
for n in range(n_rounds):
|
||||
out = self.first_stage_model.encode(
|
||||
x[n * n_samples : (n + 1) * n_samples]
|
||||
)
|
||||
all_out.append(out)
|
||||
z = torch.cat(all_out, dim=0)
|
||||
z = self.scale_factor * z
|
||||
return z
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import math
|
||||
from inspect import isfunction
|
||||
from typing import Any, Optional
|
||||
@@ -7,6 +8,9 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
SDP_IS_AVAILABLE = True
|
||||
@@ -36,9 +40,10 @@ else:
|
||||
SDP_IS_AVAILABLE = False
|
||||
sdp_kernel = nullcontext
|
||||
BACKEND_MAP = {}
|
||||
print(
|
||||
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
||||
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
||||
logpy.warn(
|
||||
f"No SDP backend available, likely because you are running in pytorch "
|
||||
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
||||
f"You might want to consider upgrading."
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -48,9 +53,9 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
logpy.warn("no module 'xformers'. Processing without...")
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
@@ -146,6 +151,62 @@ class LinearAttention(nn.Module):
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "math")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
attn_mode: str = "xformers",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
|
||||
qkv = self.qkv(x)
|
||||
if self.attn_mode == "torch":
|
||||
qkv = rearrange(
|
||||
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
||||
).float()
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
elif self.attn_mode == "xformers":
|
||||
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v)
|
||||
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
||||
elif self.attn_mode == "math":
|
||||
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@@ -289,9 +350,10 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
print(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads with a dimension of {dim_head}."
|
||||
logpy.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
||||
f"context_dim is {context_dim} and using {heads} heads with a "
|
||||
f"dimension of {dim_head}."
|
||||
)
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
@@ -352,9 +414,29 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
||||
# NOTE: workaround for
|
||||
# https://github.com/facebookresearch/xformers/issues/845
|
||||
max_bs = 32768
|
||||
N = q.shape[0]
|
||||
n_batches = math.ceil(N / max_bs)
|
||||
out = list()
|
||||
for i_batch in range(n_batches):
|
||||
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
||||
out.append(
|
||||
xformers.ops.memory_efficient_attention(
|
||||
q[batch],
|
||||
k[batch],
|
||||
v[batch],
|
||||
attn_bias=None,
|
||||
op=self.attention_op,
|
||||
)
|
||||
)
|
||||
out = torch.cat(out, 0)
|
||||
else:
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
|
||||
# TODO: Use this directly in the attention operation, as a bias
|
||||
if exists(mask):
|
||||
@@ -393,21 +475,24 @@ class BasicTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
||||
print(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
logpy.warn(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling "
|
||||
f"back to native attention. This is not a problem in "
|
||||
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
||||
f"version {torch.__version__}."
|
||||
)
|
||||
attn_mode = "softmax"
|
||||
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
||||
print(
|
||||
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
|
||||
logpy.warn(
|
||||
"We do not support vanilla attention anymore, as it is too "
|
||||
"expensive. Sorry."
|
||||
)
|
||||
if not XFORMERS_IS_AVAILABLE:
|
||||
assert (
|
||||
False
|
||||
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
else:
|
||||
print("Falling back to xformers efficient attention.")
|
||||
logpy.info("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
@@ -437,7 +522,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
@@ -456,9 +541,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
if self.checkpoint:
|
||||
# inputs = {"x": x, "context": context}
|
||||
return checkpoint(self._forward, x, context)
|
||||
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
||||
else:
|
||||
return self._forward(**kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
@@ -518,9 +606,9 @@ class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
# inputs = {"x": x, "context": context}
|
||||
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
@@ -554,18 +642,20 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
logpy.debug(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
||||
f"{in_channels} channels and {n_heads} heads."
|
||||
)
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
if exists(context_dim) and isinstance(context_dim, list):
|
||||
if depth != len(context_dim):
|
||||
print(
|
||||
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
||||
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
||||
logpy.warn(
|
||||
f"{self.__class__.__name__}: Found context dims "
|
||||
f"{context_dim} of depth {len(context_dim)}, which does not "
|
||||
f"match the specified 'depth' of {depth}. Setting context_dim "
|
||||
f"to {depth * [context_dim[0]]} now."
|
||||
)
|
||||
# depth does not match context dims.
|
||||
assert all(
|
||||
@@ -631,3 +721,39 @@ class SpatialTransformer(nn.Module):
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class SimpleTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
context_dim: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
checkpoint: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
BasicTransformerBlock(
|
||||
dim,
|
||||
heads,
|
||||
dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
attn_mode="softmax-xformers",
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, context)
|
||||
return x
|
||||
|
||||
@@ -1,246 +1,7 @@
|
||||
from typing import Any, Union
|
||||
__all__ = [
|
||||
"GeneralLPIPSWithDiscriminator",
|
||||
"LatentLPIPS",
|
||||
]
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
from ..lpips.model.model import NLayerDiscriminator, weights_init
|
||||
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
class LatentLPIPS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_config,
|
||||
perceptual_weight=1.0,
|
||||
latent_weight=1.0,
|
||||
scale_input_to_tgt_size=False,
|
||||
scale_tgt_to_input_size=False,
|
||||
perceptual_weight_on_inputs=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
||||
self.init_decoder(decoder_config)
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
self.latent_weight = latent_weight
|
||||
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
||||
|
||||
def init_decoder(self, config):
|
||||
self.decoder = instantiate_from_config(config)
|
||||
if hasattr(self.decoder, "encoder"):
|
||||
del self.decoder.encoder
|
||||
|
||||
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
||||
log = dict()
|
||||
loss = (latent_inputs - latent_predictions) ** 2
|
||||
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
||||
image_reconstructions = None
|
||||
if self.perceptual_weight > 0.0:
|
||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
||||
image_targets = self.decoder.decode(latent_inputs)
|
||||
perceptual_loss = self.perceptual_loss(
|
||||
image_targets.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = (
|
||||
self.latent_weight * loss.mean()
|
||||
+ self.perceptual_weight * perceptual_loss.mean()
|
||||
)
|
||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
||||
|
||||
if self.perceptual_weight_on_inputs > 0.0:
|
||||
image_reconstructions = default(
|
||||
image_reconstructions, self.decoder.decode(latent_predictions)
|
||||
)
|
||||
if self.scale_input_to_tgt_size:
|
||||
image_inputs = torch.nn.functional.interpolate(
|
||||
image_inputs,
|
||||
image_reconstructions.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
elif self.scale_tgt_to_input_size:
|
||||
image_reconstructions = torch.nn.functional.interpolate(
|
||||
image_reconstructions,
|
||||
image_inputs.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
perceptual_loss2 = self.perceptual_loss(
|
||||
image_inputs.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
||||
return loss, log
|
||||
|
||||
|
||||
class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
disc_start: int,
|
||||
logvar_init: float = 0.0,
|
||||
pixelloss_weight=1.0,
|
||||
disc_num_layers: int = 3,
|
||||
disc_in_channels: int = 3,
|
||||
disc_factor: float = 1.0,
|
||||
disc_weight: float = 1.0,
|
||||
perceptual_weight: float = 1.0,
|
||||
disc_loss: str = "hinge",
|
||||
scale_input_to_tgt_size: bool = False,
|
||||
dims: int = 2,
|
||||
learn_logvar: bool = False,
|
||||
regularization_weights: Union[None, dict] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
if self.dims > 2:
|
||||
print(
|
||||
f"running with dims={dims}. This means that for perceptual loss calculation, "
|
||||
f"the LPIPS loss will be applied to each frame independently. "
|
||||
)
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.pixel_weight = pixelloss_weight
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
# output log variance
|
||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
self.learn_logvar = learn_logvar
|
||||
|
||||
self.discriminator = NLayerDiscriminator(
|
||||
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
|
||||
).apply(weights_init)
|
||||
self.discriminator_iter_start = disc_start
|
||||
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.regularization_weights = default(regularization_weights, {})
|
||||
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
return self.discriminator.parameters()
|
||||
|
||||
def get_trainable_autoencoder_parameters(self) -> Any:
|
||||
if self.learn_logvar:
|
||||
yield self.logvar
|
||||
yield from ()
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(
|
||||
nll_loss, self.last_layer[0], retain_graph=True
|
||||
)[0]
|
||||
g_grads = torch.autograd.grad(
|
||||
g_loss, self.last_layer[0], retain_graph=True
|
||||
)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
regularization_log,
|
||||
inputs,
|
||||
reconstructions,
|
||||
optimizer_idx,
|
||||
global_step,
|
||||
last_layer=None,
|
||||
split="train",
|
||||
weights=None,
|
||||
):
|
||||
if self.scale_input_to_tgt_size:
|
||||
inputs = torch.nn.functional.interpolate(
|
||||
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
||||
)
|
||||
|
||||
if self.dims > 2:
|
||||
inputs, reconstructions = map(
|
||||
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
||||
(inputs, reconstructions),
|
||||
)
|
||||
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0:
|
||||
p_loss = self.perceptual_loss(
|
||||
inputs.contiguous(), reconstructions.contiguous()
|
||||
)
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
|
||||
if self.disc_factor > 0.0:
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(
|
||||
nll_loss, g_loss, last_layer=last_layer
|
||||
)
|
||||
except RuntimeError:
|
||||
assert not self.training
|
||||
d_weight = torch.tensor(0.0)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(
|
||||
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
||||
)
|
||||
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
|
||||
log = dict()
|
||||
for k in regularization_log:
|
||||
if k in self.regularization_weights:
|
||||
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
||||
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
|
||||
|
||||
log.update(
|
||||
{
|
||||
"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
"{}/logvar".format(split): self.logvar.detach(),
|
||||
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
||||
"{}/d_weight".format(split): d_weight.detach(),
|
||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
||||
}
|
||||
)
|
||||
|
||||
return loss, log
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
|
||||
disc_factor = adopt_weight(
|
||||
self.disc_factor, global_step, threshold=self.discriminator_iter_start
|
||||
)
|
||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
|
||||
log = {
|
||||
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
||||
"{}/logits_fake".format(split): logits_fake.detach().mean(),
|
||||
}
|
||||
return d_loss, log
|
||||
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
||||
from .lpips import LatentLPIPS
|
||||
|
||||
306
sgm/modules/autoencoding/losses/discriminator_loss.py
Normal file
306
sgm/modules/autoencoding/losses/discriminator_loss.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from einops import rearrange
|
||||
from matplotlib import colormaps
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
from ..lpips.model.model import weights_init
|
||||
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
disc_start: int,
|
||||
logvar_init: float = 0.0,
|
||||
disc_num_layers: int = 3,
|
||||
disc_in_channels: int = 3,
|
||||
disc_factor: float = 1.0,
|
||||
disc_weight: float = 1.0,
|
||||
perceptual_weight: float = 1.0,
|
||||
disc_loss: str = "hinge",
|
||||
scale_input_to_tgt_size: bool = False,
|
||||
dims: int = 2,
|
||||
learn_logvar: bool = False,
|
||||
regularization_weights: Union[None, Dict[str, float]] = None,
|
||||
additional_log_keys: Optional[List[str]] = None,
|
||||
discriminator_config: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
if self.dims > 2:
|
||||
print(
|
||||
f"running with dims={dims}. This means that for perceptual loss "
|
||||
f"calculation, the LPIPS loss will be applied to each frame "
|
||||
f"independently."
|
||||
)
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
# output log variance
|
||||
self.logvar = nn.Parameter(
|
||||
torch.full((), logvar_init), requires_grad=learn_logvar
|
||||
)
|
||||
self.learn_logvar = learn_logvar
|
||||
|
||||
discriminator_config = default(
|
||||
discriminator_config,
|
||||
{
|
||||
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
||||
"params": {
|
||||
"input_nc": disc_in_channels,
|
||||
"n_layers": disc_num_layers,
|
||||
"use_actnorm": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
self.discriminator = instantiate_from_config(discriminator_config).apply(
|
||||
weights_init
|
||||
)
|
||||
self.discriminator_iter_start = disc_start
|
||||
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.regularization_weights = default(regularization_weights, {})
|
||||
|
||||
self.forward_keys = [
|
||||
"optimizer_idx",
|
||||
"global_step",
|
||||
"last_layer",
|
||||
"split",
|
||||
"regularization_log",
|
||||
]
|
||||
|
||||
self.additional_log_keys = set(default(additional_log_keys, []))
|
||||
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
||||
|
||||
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
||||
return self.discriminator.parameters()
|
||||
|
||||
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
||||
if self.learn_logvar:
|
||||
yield self.logvar
|
||||
yield from ()
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# calc logits of real/fake
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
if len(logits_real.shape) < 4:
|
||||
# Non patch-discriminator
|
||||
return dict()
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
# -> (b, 1, h, w)
|
||||
|
||||
# parameters for colormapping
|
||||
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
||||
cmap = colormaps["PiYG"] # diverging colormap
|
||||
|
||||
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""(b, 1, ...) -> (b, 3, ...)"""
|
||||
logits = (logits + high) / (2 * high)
|
||||
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
||||
# -> (b, 1, ..., 3)
|
||||
logits = torch.from_numpy(logits_np).to(logits.device)
|
||||
return rearrange(logits, "b 1 ... c -> b c ...")
|
||||
|
||||
logits_real = torch.nn.functional.interpolate(
|
||||
logits_real,
|
||||
size=inputs.shape[-2:],
|
||||
mode="nearest",
|
||||
antialias=False,
|
||||
)
|
||||
logits_fake = torch.nn.functional.interpolate(
|
||||
logits_fake,
|
||||
size=reconstructions.shape[-2:],
|
||||
mode="nearest",
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
# alpha value of logits for overlay
|
||||
alpha_real = torch.abs(logits_real) / high
|
||||
alpha_fake = torch.abs(logits_fake) / high
|
||||
# -> (b, 1, h, w) in range [0, 0.5]
|
||||
# alpha value of lines don't really matter, since the values are the same
|
||||
# for both images and logits anyway
|
||||
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
||||
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
||||
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
||||
# -> (1, h, w)
|
||||
# blend logits and images together
|
||||
|
||||
# prepare logits for plotting
|
||||
logits_real = to_colormap(logits_real)
|
||||
logits_fake = to_colormap(logits_fake)
|
||||
# resize logits
|
||||
# -> (b, 3, h, w)
|
||||
|
||||
# make some grids
|
||||
# add all logits to one plot
|
||||
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
||||
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
||||
# I just love how torchvision calls the number of columns `nrow`
|
||||
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
||||
# -> (3, h, w)
|
||||
|
||||
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
||||
grid_images_fake = torchvision.utils.make_grid(
|
||||
0.5 * reconstructions + 0.5, nrow=4
|
||||
)
|
||||
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
||||
# -> (3, h, w) in range [0, 1]
|
||||
|
||||
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
||||
|
||||
# Create labeled colorbar
|
||||
dpi = 100
|
||||
height = 128 / dpi
|
||||
width = grid_logits.shape[2] / dpi
|
||||
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
||||
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
||||
plt.colorbar(
|
||||
img,
|
||||
cax=ax,
|
||||
orientation="horizontal",
|
||||
fraction=0.9,
|
||||
aspect=width / height,
|
||||
pad=0.0,
|
||||
)
|
||||
img.set_visible(False)
|
||||
fig.tight_layout()
|
||||
fig.canvas.draw()
|
||||
# manually convert figure to numpy
|
||||
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
||||
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
||||
|
||||
# Add colorbar to plot
|
||||
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
||||
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
||||
return {
|
||||
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
||||
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
||||
}
|
||||
|
||||
def calculate_adaptive_weight(
|
||||
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
reconstructions: torch.Tensor,
|
||||
*, # added because I changed the order here
|
||||
regularization_log: Dict[str, torch.Tensor],
|
||||
optimizer_idx: int,
|
||||
global_step: int,
|
||||
last_layer: torch.Tensor,
|
||||
split: str = "train",
|
||||
weights: Union[None, float, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
if self.scale_input_to_tgt_size:
|
||||
inputs = torch.nn.functional.interpolate(
|
||||
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
||||
)
|
||||
|
||||
if self.dims > 2:
|
||||
inputs, reconstructions = map(
|
||||
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
||||
(inputs, reconstructions),
|
||||
)
|
||||
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0:
|
||||
p_loss = self.perceptual_loss(
|
||||
inputs.contiguous(), reconstructions.contiguous()
|
||||
)
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
if global_step >= self.discriminator_iter_start or not self.training:
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
if self.training:
|
||||
d_weight = self.calculate_adaptive_weight(
|
||||
nll_loss, g_loss, last_layer=last_layer
|
||||
)
|
||||
else:
|
||||
d_weight = torch.tensor(1.0)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
g_loss = torch.tensor(0.0, requires_grad=True)
|
||||
|
||||
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
||||
log = dict()
|
||||
for k in regularization_log:
|
||||
if k in self.regularization_weights:
|
||||
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
||||
if k in self.additional_log_keys:
|
||||
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
||||
|
||||
log.update(
|
||||
{
|
||||
f"{split}/loss/total": loss.clone().detach().mean(),
|
||||
f"{split}/loss/nll": nll_loss.detach().mean(),
|
||||
f"{split}/loss/rec": rec_loss.detach().mean(),
|
||||
f"{split}/loss/g": g_loss.detach().mean(),
|
||||
f"{split}/scalars/logvar": self.logvar.detach(),
|
||||
f"{split}/scalars/d_weight": d_weight.detach(),
|
||||
}
|
||||
)
|
||||
|
||||
return loss, log
|
||||
elif optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
|
||||
if global_step >= self.discriminator_iter_start or not self.training:
|
||||
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
else:
|
||||
d_loss = torch.tensor(0.0, requires_grad=True)
|
||||
|
||||
log = {
|
||||
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
||||
f"{split}/logits/real": logits_real.detach().mean(),
|
||||
f"{split}/logits/fake": logits_fake.detach().mean(),
|
||||
}
|
||||
return d_loss, log
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
||||
|
||||
def get_nll_loss(
|
||||
self,
|
||||
rec_loss: torch.Tensor,
|
||||
weights: Optional[Union[float, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
|
||||
return nll_loss, weighted_nll_loss
|
||||
73
sgm/modules/autoencoding/losses/lpips.py
Normal file
73
sgm/modules/autoencoding/losses/lpips.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
|
||||
|
||||
class LatentLPIPS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_config,
|
||||
perceptual_weight=1.0,
|
||||
latent_weight=1.0,
|
||||
scale_input_to_tgt_size=False,
|
||||
scale_tgt_to_input_size=False,
|
||||
perceptual_weight_on_inputs=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
||||
self.init_decoder(decoder_config)
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
self.latent_weight = latent_weight
|
||||
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
||||
|
||||
def init_decoder(self, config):
|
||||
self.decoder = instantiate_from_config(config)
|
||||
if hasattr(self.decoder, "encoder"):
|
||||
del self.decoder.encoder
|
||||
|
||||
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
||||
log = dict()
|
||||
loss = (latent_inputs - latent_predictions) ** 2
|
||||
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
||||
image_reconstructions = None
|
||||
if self.perceptual_weight > 0.0:
|
||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
||||
image_targets = self.decoder.decode(latent_inputs)
|
||||
perceptual_loss = self.perceptual_loss(
|
||||
image_targets.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = (
|
||||
self.latent_weight * loss.mean()
|
||||
+ self.perceptual_weight * perceptual_loss.mean()
|
||||
)
|
||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
||||
|
||||
if self.perceptual_weight_on_inputs > 0.0:
|
||||
image_reconstructions = default(
|
||||
image_reconstructions, self.decoder.decode(latent_predictions)
|
||||
)
|
||||
if self.scale_input_to_tgt_size:
|
||||
image_inputs = torch.nn.functional.interpolate(
|
||||
image_inputs,
|
||||
image_reconstructions.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
elif self.scale_tgt_to_input_size:
|
||||
image_reconstructions = torch.nn.functional.interpolate(
|
||||
image_reconstructions,
|
||||
image_inputs.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
perceptual_loss2 = self.perceptual_loss(
|
||||
image_inputs.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
||||
return loss, log
|
||||
@@ -5,19 +5,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ....modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class AbstractRegularizer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
from ....modules.distributions.distributions import \
|
||||
DiagonalGaussianDistribution
|
||||
from .base import AbstractRegularizer
|
||||
|
||||
|
||||
class DiagonalGaussianRegularizer(AbstractRegularizer):
|
||||
@@ -39,15 +29,3 @@ class DiagonalGaussianRegularizer(AbstractRegularizer):
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
log["kl_loss"] = kl_loss
|
||||
return z, log
|
||||
|
||||
|
||||
def measure_perplexity(predicted_indices, num_centroids):
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = (
|
||||
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
)
|
||||
avg_probs = encodings.mean(0)
|
||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
||||
cluster_use = torch.sum(avg_probs > 0)
|
||||
return perplexity, cluster_use
|
||||
|
||||
40
sgm/modules/autoencoding/regularizers/base.py
Normal file
40
sgm/modules/autoencoding/regularizers/base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class AbstractRegularizer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class IdentityRegularizer(AbstractRegularizer):
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
return z, dict()
|
||||
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
yield from ()
|
||||
|
||||
|
||||
def measure_perplexity(
|
||||
predicted_indices: torch.Tensor, num_centroids: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = (
|
||||
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
)
|
||||
avg_probs = encodings.mean(0)
|
||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
||||
cluster_use = torch.sum(avg_probs > 0)
|
||||
return perplexity, cluster_use
|
||||
487
sgm/modules/autoencoding/regularizers/quantize.py
Normal file
487
sgm/modules/autoencoding/regularizers/quantize.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
from .base import AbstractRegularizer, measure_perplexity
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractQuantizer(AbstractRegularizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Define these in your init
|
||||
# shape (N,)
|
||||
self.used: Optional[torch.Tensor]
|
||||
self.re_embed: int
|
||||
self.unknown_index: Union[Literal["random"], int]
|
||||
|
||||
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
|
||||
assert self.used is not None, "You need to define used indices for remap"
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
|
||||
assert self.used is not None, "You need to define used indices for remap"
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
@abstractmethod
|
||||
def get_codebook_entry(
|
||||
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
||||
yield from self.parameters()
|
||||
|
||||
|
||||
class GumbelQuantizer(AbstractQuantizer):
|
||||
"""
|
||||
credit to @karpathy:
|
||||
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
||||
Gumbel Softmax trick quantizer
|
||||
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
||||
https://arxiv.org/abs/1611.01144
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_hiddens: int,
|
||||
embedding_dim: int,
|
||||
n_embed: int,
|
||||
straight_through: bool = True,
|
||||
kl_weight: float = 5e-4,
|
||||
temp_init: float = 1.0,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
loss_key: str = "loss/vq",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.loss_key = loss_key
|
||||
self.embedding_dim = embedding_dim
|
||||
self.n_embed = n_embed
|
||||
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
|
||||
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
||||
self.embed = nn.Embedding(n_embed, embedding_dim)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_embed
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
# force hard = True when we are in eval mode, as we must quantize.
|
||||
# actually, always true seems to work
|
||||
hard = self.straight_through if self.training else True
|
||||
temp = self.temperature if temp is None else temp
|
||||
out_dict = {}
|
||||
logits = self.proj(z)
|
||||
if self.remap is not None:
|
||||
# continue only with used logits
|
||||
full_zeros = torch.zeros_like(logits)
|
||||
logits = logits[:, self.used, ...]
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
||||
if self.remap is not None:
|
||||
# go back to all entries but unused set to zero
|
||||
full_zeros[:, self.used, ...] = soft_one_hot
|
||||
soft_one_hot = full_zeros
|
||||
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = (
|
||||
self.kl_weight
|
||||
* torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
||||
)
|
||||
out_dict[self.loss_key] = diff
|
||||
|
||||
ind = soft_one_hot.argmax(dim=1)
|
||||
out_dict["indices"] = ind
|
||||
if self.remap is not None:
|
||||
ind = self.remap_to_used(ind)
|
||||
|
||||
if return_logits:
|
||||
out_dict["logits"] = logits
|
||||
|
||||
return z_q, out_dict
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# TODO: shape not yet optional
|
||||
b, h, w, c = shape
|
||||
assert b * h * w == indices.shape[0]
|
||||
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
|
||||
if self.remap is not None:
|
||||
indices = self.unmap_to_all(indices)
|
||||
one_hot = (
|
||||
F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
||||
)
|
||||
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
|
||||
return z_q
|
||||
|
||||
|
||||
class VectorQuantizer(AbstractQuantizer):
|
||||
"""
|
||||
____________________________________________
|
||||
Discretization bottleneck part of the VQ-VAE.
|
||||
Inputs:
|
||||
- n_e : number of embeddings
|
||||
- e_dim : dimension of embedding
|
||||
- beta : commitment cost used in loss term,
|
||||
beta * ||z_e(x)-sg[e]||^2
|
||||
_____________________________________________
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_e: int,
|
||||
e_dim: int,
|
||||
beta: float = 0.25,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
sane_index_shape: bool = False,
|
||||
log_perplexity: bool = False,
|
||||
embedding_weight_norm: bool = False,
|
||||
loss_key: str = "loss/vq",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.loss_key = loss_key
|
||||
|
||||
if not embedding_weight_norm:
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
else:
|
||||
self.embedding = torch.nn.utils.weight_norm(
|
||||
nn.Embedding(self.n_e, self.e_dim), dim=1
|
||||
)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_e
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
self.log_perplexity = log_perplexity
|
||||
|
||||
def forward(
|
||||
self,
|
||||
z: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
do_reshape = z.ndim == 4
|
||||
if do_reshape:
|
||||
# # reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
||||
|
||||
else:
|
||||
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
|
||||
z = z.contiguous()
|
||||
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2
|
||||
* torch.einsum(
|
||||
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
|
||||
)
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
loss_dict = {}
|
||||
if self.log_perplexity:
|
||||
perplexity, cluster_usage = measure_perplexity(
|
||||
min_encoding_indices.detach(), self.n_e
|
||||
)
|
||||
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
loss_dict[self.loss_key] = loss
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
if do_reshape:
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z.shape[0], -1
|
||||
) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
if do_reshape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||
)
|
||||
else:
|
||||
min_encoding_indices = rearrange(
|
||||
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
|
||||
)
|
||||
|
||||
loss_dict["min_encoding_indices"] = min_encoding_indices
|
||||
|
||||
return z_q, loss_dict
|
||||
|
||||
def get_codebook_entry(
|
||||
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||
) -> torch.Tensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
assert shape is not None, "Need to give shape for remap"
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class EmbeddingEMA(nn.Module):
|
||||
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
weight = torch.randn(num_tokens, codebook_dim)
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
||||
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
||||
self.update = True
|
||||
|
||||
def forward(self, embed_id):
|
||||
return F.embedding(embed_id, self.weight)
|
||||
|
||||
def cluster_size_ema_update(self, new_cluster_size):
|
||||
self.cluster_size.data.mul_(self.decay).add_(
|
||||
new_cluster_size, alpha=1 - self.decay
|
||||
)
|
||||
|
||||
def embed_avg_ema_update(self, new_embed_avg):
|
||||
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
||||
|
||||
def weight_update(self, num_tokens):
|
||||
n = self.cluster_size.sum()
|
||||
smoothed_cluster_size = (
|
||||
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
||||
)
|
||||
# normalize embedding average with smoothed cluster size
|
||||
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
||||
self.weight.data.copy_(embed_normalized)
|
||||
|
||||
|
||||
class EMAVectorQuantizer(AbstractQuantizer):
|
||||
def __init__(
|
||||
self,
|
||||
n_embed: int,
|
||||
embedding_dim: int,
|
||||
beta: float,
|
||||
decay: float = 0.99,
|
||||
eps: float = 1e-5,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
loss_key: str = "loss/vq",
|
||||
):
|
||||
super().__init__()
|
||||
self.codebook_dim = embedding_dim
|
||||
self.num_tokens = n_embed
|
||||
self.beta = beta
|
||||
self.loss_key = loss_key
|
||||
|
||||
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_embed
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
# z, 'b c h w -> b h w c'
|
||||
z = rearrange(z, "b c h w -> b h w c")
|
||||
z_flattened = z.reshape(-1, self.codebook_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (
|
||||
z_flattened.pow(2).sum(dim=1, keepdim=True)
|
||||
+ self.embedding.weight.pow(2).sum(dim=1)
|
||||
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
|
||||
) # 'n d -> d n'
|
||||
|
||||
encoding_indices = torch.argmin(d, dim=1)
|
||||
|
||||
z_q = self.embedding(encoding_indices).view(z.shape)
|
||||
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
||||
avg_probs = torch.mean(encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
||||
|
||||
if self.training and self.embedding.update:
|
||||
# EMA cluster size
|
||||
encodings_sum = encodings.sum(0)
|
||||
self.embedding.cluster_size_ema_update(encodings_sum)
|
||||
# EMA embedding average
|
||||
embed_sum = encodings.transpose(0, 1) @ z_flattened
|
||||
self.embedding.embed_avg_ema_update(embed_sum)
|
||||
# normalize embed_avg and update weight
|
||||
self.embedding.weight_update(self.num_tokens)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
# z_q, 'b h w c -> b c h w'
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w")
|
||||
|
||||
out_dict = {
|
||||
self.loss_key: loss,
|
||||
"encodings": encodings,
|
||||
"encoding_indices": encoding_indices,
|
||||
"perplexity": perplexity,
|
||||
}
|
||||
|
||||
return z_q, out_dict
|
||||
|
||||
|
||||
class VectorQuantizerWithInputProjection(VectorQuantizer):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
n_codes: int,
|
||||
codebook_dim: int,
|
||||
beta: float = 1.0,
|
||||
output_dim: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(n_codes, codebook_dim, beta, **kwargs)
|
||||
self.proj_in = nn.Linear(input_dim, codebook_dim)
|
||||
self.output_dim = output_dim
|
||||
if output_dim is not None:
|
||||
self.proj_out = nn.Linear(codebook_dim, output_dim)
|
||||
else:
|
||||
self.proj_out = nn.Identity()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
||||
rearr = False
|
||||
in_shape = z.shape
|
||||
|
||||
if z.ndim > 3:
|
||||
rearr = self.output_dim is not None
|
||||
z = rearrange(z, "b c ... -> b (...) c")
|
||||
z = self.proj_in(z)
|
||||
z_q, loss_dict = super().forward(z)
|
||||
|
||||
z_q = self.proj_out(z_q)
|
||||
if rearr:
|
||||
if len(in_shape) == 4:
|
||||
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
|
||||
elif len(in_shape) == 5:
|
||||
z_q = rearrange(
|
||||
z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"rearranging not available for {len(in_shape)}-dimensional input."
|
||||
)
|
||||
|
||||
return z_q, loss_dict
|
||||
349
sgm/modules/autoencoding/temporal_ae.py
Normal file
349
sgm/modules/autoencoding/temporal_ae.py
Normal file
@@ -0,0 +1,349 @@
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from sgm.modules.diffusionmodules.model import (
|
||||
XFORMERS_IS_AVAILABLE,
|
||||
AttnBlock,
|
||||
Decoder,
|
||||
MemoryEfficientAttnBlock,
|
||||
ResnetBlock,
|
||||
)
|
||||
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
||||
from sgm.modules.video_attention import VideoTransformerBlock
|
||||
from sgm.util import partialclass
|
||||
|
||||
|
||||
class VideoResBlock(ResnetBlock):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
*args,
|
||||
dropout=0.0,
|
||||
video_kernel_size=3,
|
||||
alpha=0.0,
|
||||
merge_strategy="learned",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
||||
if video_kernel_size is None:
|
||||
video_kernel_size = [3, 1, 1]
|
||||
self.time_stack = ResBlock(
|
||||
channels=out_channels,
|
||||
emb_channels=0,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=False,
|
||||
skip_t_emb=True,
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, bs):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, temb, skip_video=False, timesteps=None):
|
||||
if timesteps is None:
|
||||
timesteps = self.timesteps
|
||||
|
||||
b, c, h, w = x.shape
|
||||
|
||||
x = super().forward(x, temb)
|
||||
|
||||
if not skip_video:
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = self.time_stack(x, temb)
|
||||
|
||||
alpha = self.get_alpha(bs=b // timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class AE3DConv(torch.nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||
if isinstance(video_kernel_size, Iterable):
|
||||
padding = [int(k // 2) for k in video_kernel_size]
|
||||
else:
|
||||
padding = int(video_kernel_size // 2)
|
||||
|
||||
self.time_mix_conv = torch.nn.Conv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=video_kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def forward(self, input, timesteps, skip_video=False):
|
||||
x = super().forward(input)
|
||||
if skip_video:
|
||||
return x
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
x = self.time_mix_conv(x)
|
||||
return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
|
||||
class VideoBlock(AttnBlock):
|
||||
def __init__(
|
||||
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
||||
):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = VideoTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
attn_mode="softmax",
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps, skip_video=False):
|
||||
if skip_video:
|
||||
return super().forward(x)
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
|
||||
def __init__(
|
||||
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
||||
):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = VideoTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
attn_mode="softmax-xformers",
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps, skip_time_block=False):
|
||||
if skip_time_block:
|
||||
return super().forward(x)
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
def make_time_attn(
|
||||
in_channels,
|
||||
attn_type="vanilla",
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
):
|
||||
assert attn_type in [
|
||||
"vanilla",
|
||||
"vanilla-xformers",
|
||||
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
||||
print(
|
||||
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
|
||||
)
|
||||
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
||||
print(
|
||||
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
)
|
||||
attn_type = "vanilla"
|
||||
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return partialclass(
|
||||
VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||
)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return partialclass(
|
||||
MemoryEfficientVideoBlock,
|
||||
in_channels,
|
||||
alpha=alpha,
|
||||
merge_strategy=merge_strategy,
|
||||
)
|
||||
else:
|
||||
return NotImplementedError()
|
||||
|
||||
|
||||
class Conv2DWrapper(torch.nn.Conv2d):
|
||||
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
class VideoDecoder(Decoder):
|
||||
available_time_modes = ["all", "conv-only", "attn-only"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
video_kernel_size: Union[int, list] = 3,
|
||||
alpha: float = 0.0,
|
||||
merge_strategy: str = "learned",
|
||||
time_mode: str = "conv-only",
|
||||
**kwargs,
|
||||
):
|
||||
self.video_kernel_size = video_kernel_size
|
||||
self.alpha = alpha
|
||||
self.merge_strategy = merge_strategy
|
||||
self.time_mode = time_mode
|
||||
assert (
|
||||
self.time_mode in self.available_time_modes
|
||||
), f"time_mode parameter has to be in {self.available_time_modes}"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_last_layer(self, skip_time_mix=False, **kwargs):
|
||||
if self.time_mode == "attn-only":
|
||||
raise NotImplementedError("TODO")
|
||||
else:
|
||||
return (
|
||||
self.conv_out.time_mix_conv.weight
|
||||
if not skip_time_mix
|
||||
else self.conv_out.weight
|
||||
)
|
||||
|
||||
def _make_attn(self) -> Callable:
|
||||
if self.time_mode not in ["conv-only", "only-last-conv"]:
|
||||
return partialclass(
|
||||
make_time_attn,
|
||||
alpha=self.alpha,
|
||||
merge_strategy=self.merge_strategy,
|
||||
)
|
||||
else:
|
||||
return super()._make_attn()
|
||||
|
||||
def _make_conv(self) -> Callable:
|
||||
if self.time_mode != "attn-only":
|
||||
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
|
||||
else:
|
||||
return Conv2DWrapper
|
||||
|
||||
def _make_resblock(self) -> Callable:
|
||||
if self.time_mode not in ["attn-only", "only-last-conv"]:
|
||||
return partialclass(
|
||||
VideoResBlock,
|
||||
video_kernel_size=self.video_kernel_size,
|
||||
alpha=self.alpha,
|
||||
merge_strategy=self.merge_strategy,
|
||||
)
|
||||
else:
|
||||
return super()._make_resblock()
|
||||
@@ -1,7 +0,0 @@
|
||||
from .denoiser import Denoiser
|
||||
from .discretizer import Discretization
|
||||
from .loss import StandardDiffusionLoss
|
||||
from .model import Decoder, Encoder, Model
|
||||
from .openaimodel import UNetModel
|
||||
from .sampling import BaseDiffusionSampler
|
||||
from .wrappers import OpenAIWrapper
|
||||
|
||||
@@ -1,62 +1,74 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from .denoiser_scaling import DenoiserScaling
|
||||
from .discretizer import Discretization
|
||||
|
||||
|
||||
class Denoiser(nn.Module):
|
||||
def __init__(self, weighting_config, scaling_config):
|
||||
def __init__(self, scaling_config: Dict):
|
||||
super().__init__()
|
||||
|
||||
self.weighting = instantiate_from_config(weighting_config)
|
||||
self.scaling = instantiate_from_config(scaling_config)
|
||||
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
|
||||
|
||||
def possibly_quantize_sigma(self, sigma):
|
||||
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return sigma
|
||||
|
||||
def possibly_quantize_c_noise(self, c_noise):
|
||||
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
|
||||
return c_noise
|
||||
|
||||
def w(self, sigma):
|
||||
return self.weighting(sigma)
|
||||
|
||||
def __call__(self, network, input, sigma, cond):
|
||||
def forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
input: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
cond: Dict,
|
||||
**additional_model_inputs,
|
||||
) -> torch.Tensor:
|
||||
sigma = self.possibly_quantize_sigma(sigma)
|
||||
sigma_shape = sigma.shape
|
||||
sigma = append_dims(sigma, input.ndim)
|
||||
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
|
||||
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
|
||||
return network(input * c_in, c_noise, cond) * c_out + input * c_skip
|
||||
return (
|
||||
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
|
||||
+ input * c_skip
|
||||
)
|
||||
|
||||
|
||||
class DiscreteDenoiser(Denoiser):
|
||||
def __init__(
|
||||
self,
|
||||
weighting_config,
|
||||
scaling_config,
|
||||
num_idx,
|
||||
discretization_config,
|
||||
do_append_zero=False,
|
||||
quantize_c_noise=True,
|
||||
flip=True,
|
||||
scaling_config: Dict,
|
||||
num_idx: int,
|
||||
discretization_config: Dict,
|
||||
do_append_zero: bool = False,
|
||||
quantize_c_noise: bool = True,
|
||||
flip: bool = True,
|
||||
):
|
||||
super().__init__(weighting_config, scaling_config)
|
||||
sigmas = instantiate_from_config(discretization_config)(
|
||||
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
super().__init__(scaling_config)
|
||||
self.discretization: Discretization = instantiate_from_config(
|
||||
discretization_config
|
||||
)
|
||||
sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
self.register_buffer("sigmas", sigmas)
|
||||
self.quantize_c_noise = quantize_c_noise
|
||||
self.num_idx = num_idx
|
||||
|
||||
def sigma_to_idx(self, sigma):
|
||||
def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
dists = sigma - self.sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
|
||||
return self.sigmas[idx]
|
||||
|
||||
def possibly_quantize_sigma(self, sigma):
|
||||
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return self.idx_to_sigma(self.sigma_to_idx(sigma))
|
||||
|
||||
def possibly_quantize_c_noise(self, c_noise):
|
||||
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantize_c_noise:
|
||||
return self.sigma_to_idx(c_noise)
|
||||
else:
|
||||
|
||||
@@ -1,11 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DenoiserScaling(ABC):
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class EDMScaling:
|
||||
def __init__(self, sigma_data=0.5):
|
||||
def __init__(self, sigma_data: float = 0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma):
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
||||
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
@@ -14,7 +27,9 @@ class EDMScaling:
|
||||
|
||||
|
||||
class EpsScaling:
|
||||
def __call__(self, sigma):
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = torch.ones_like(sigma, device=sigma.device)
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
||||
@@ -23,9 +38,22 @@ class EpsScaling:
|
||||
|
||||
|
||||
class VScaling:
|
||||
def __call__(self, sigma):
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = sigma.clone()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class VScalingWithEDMcNoise(DenoiserScaling):
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = 0.25 * sigma.log()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
@@ -1,31 +1,33 @@
|
||||
from functools import partial
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ...util import default, instantiate_from_config
|
||||
from ...util import append_dims, default
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VanillaCFG:
|
||||
"""
|
||||
implements parallelized CFG
|
||||
"""
|
||||
class Guider(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def __init__(self, scale, dyn_thresh_config=None):
|
||||
scale_schedule = lambda scale, sigma: scale # independent of step
|
||||
self.scale_schedule = partial(scale_schedule, scale)
|
||||
self.dyn_thresh = instantiate_from_config(
|
||||
default(
|
||||
dyn_thresh_config,
|
||||
{
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
},
|
||||
)
|
||||
)
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||
) -> Tuple[torch.Tensor, float, Dict]:
|
||||
pass
|
||||
|
||||
def __call__(self, x, sigma):
|
||||
|
||||
class VanillaCFG(Guider):
|
||||
def __init__(self, scale: float):
|
||||
self.scale = scale
|
||||
|
||||
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
scale_value = self.scale_schedule(sigma)
|
||||
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
|
||||
x_pred = x_u + self.scale * (x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
@@ -40,14 +42,58 @@ class VanillaCFG:
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
|
||||
class IdentityGuider:
|
||||
def __call__(self, x, sigma):
|
||||
class IdentityGuider(Guider):
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||
) -> Tuple[torch.Tensor, float, Dict]:
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
c_out[k] = c[k]
|
||||
|
||||
return x, s, c_out
|
||||
|
||||
|
||||
class LinearPredictionGuider(Guider):
|
||||
def __init__(
|
||||
self,
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
min_scale: float = 1.0,
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
self.min_scale = min_scale
|
||||
self.max_scale = max_scale
|
||||
self.num_frames = num_frames
|
||||
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
|
||||
|
||||
additional_cond_keys = default(additional_cond_keys, [])
|
||||
if isinstance(additional_cond_keys, str):
|
||||
additional_cond_keys = [additional_cond_keys]
|
||||
self.additional_cond_keys = additional_cond_keys
|
||||
|
||||
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
|
||||
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
|
||||
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
|
||||
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
|
||||
scale = append_dims(scale, x_u.ndim).to(x_u.device)
|
||||
|
||||
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
|
||||
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||
else:
|
||||
assert c[k] == uc[k]
|
||||
c_out[k] = c[k]
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
@@ -1,31 +1,34 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import ListConfig
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
|
||||
from ...modules.encoders.modules import GeneralConditioner
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from .denoiser import Denoiser
|
||||
|
||||
|
||||
class StandardDiffusionLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sigma_sampler_config,
|
||||
type="l2",
|
||||
offset_noise_level=0.0,
|
||||
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
|
||||
sigma_sampler_config: dict,
|
||||
loss_weighting_config: dict,
|
||||
loss_type: str = "l2",
|
||||
offset_noise_level: float = 0.0,
|
||||
batch2model_keys: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert type in ["l2", "l1", "lpips"]
|
||||
assert loss_type in ["l2", "l1", "lpips"]
|
||||
|
||||
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
|
||||
self.loss_weighting = instantiate_from_config(loss_weighting_config)
|
||||
|
||||
self.type = type
|
||||
self.loss_type = loss_type
|
||||
self.offset_noise_level = offset_noise_level
|
||||
|
||||
if type == "lpips":
|
||||
if loss_type == "lpips":
|
||||
self.lpips = LPIPS().eval()
|
||||
|
||||
if not batch2model_keys:
|
||||
@@ -36,34 +39,67 @@ class StandardDiffusionLoss(nn.Module):
|
||||
|
||||
self.batch2model_keys = set(batch2model_keys)
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
def get_noised_input(
|
||||
self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
noised_input = input + noise * sigmas_bc
|
||||
return noised_input
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
denoiser: Denoiser,
|
||||
conditioner: GeneralConditioner,
|
||||
input: torch.Tensor,
|
||||
batch: Dict,
|
||||
) -> torch.Tensor:
|
||||
cond = conditioner(batch)
|
||||
return self._forward(network, denoiser, cond, input, batch)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
denoiser: Denoiser,
|
||||
cond: Dict,
|
||||
input: torch.Tensor,
|
||||
batch: Dict,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
additional_model_inputs = {
|
||||
key: batch[key] for key in self.batch2model_keys.intersection(batch)
|
||||
}
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input)
|
||||
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
|
||||
noise = torch.randn_like(input)
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = noise + self.offset_noise_level * append_dims(
|
||||
torch.randn(input.shape[0], device=input.device), input.ndim
|
||||
offset_shape = (
|
||||
(input.shape[0], 1, input.shape[2])
|
||||
if self.n_frames is not None
|
||||
else (input.shape[0], input.shape[1])
|
||||
)
|
||||
noised_input = input + noise * append_dims(sigmas, input.ndim)
|
||||
noise = noise + self.offset_noise_level * append_dims(
|
||||
torch.randn(offset_shape, device=input.device),
|
||||
input.ndim,
|
||||
)
|
||||
sigmas_bc = append_dims(sigmas, input.ndim)
|
||||
noised_input = self.get_noised_input(sigmas_bc, noise, input)
|
||||
|
||||
model_output = denoiser(
|
||||
network, noised_input, sigmas, cond, **additional_model_inputs
|
||||
)
|
||||
w = append_dims(denoiser.w(sigmas), input.ndim)
|
||||
w = append_dims(self.loss_weighting(sigmas), input.ndim)
|
||||
return self.get_loss(model_output, input, w)
|
||||
|
||||
def get_loss(self, model_output, target, w):
|
||||
if self.type == "l2":
|
||||
if self.loss_type == "l2":
|
||||
return torch.mean(
|
||||
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
|
||||
)
|
||||
elif self.type == "l1":
|
||||
elif self.loss_type == "l1":
|
||||
return torch.mean(
|
||||
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
|
||||
)
|
||||
elif self.type == "lpips":
|
||||
elif self.loss_type == "lpips":
|
||||
loss = self.lpips(model_output, target).reshape(-1)
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {self.loss_type}")
|
||||
|
||||
32
sgm/modules/diffusionmodules/loss_weighting.py
Normal file
32
sgm/modules/diffusionmodules/loss_weighting.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DiffusionLossWeighting(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class UnitWeighting(DiffusionLossWeighting):
|
||||
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ones_like(sigma, device=sigma.device)
|
||||
|
||||
|
||||
class EDMWeighting(DiffusionLossWeighting):
|
||||
def __init__(self, sigma_data: float = 0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
||||
|
||||
|
||||
class VWeighting(EDMWeighting):
|
||||
def __init__(self):
|
||||
super().__init__(sigma_data=1.0)
|
||||
|
||||
|
||||
class EpsWeighting(DiffusionLossWeighting):
|
||||
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
||||
return sigma**-2.0
|
||||
@@ -1,4 +1,5 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -8,6 +9,8 @@ import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@@ -15,7 +18,7 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
logpy.warning("no module 'xformers'. Processing without...")
|
||||
|
||||
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
||||
|
||||
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
)
|
||||
attn_type = "vanilla-xformers"
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
logpy.info(
|
||||
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
|
||||
)
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
@@ -633,7 +638,7 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print(
|
||||
logpy.info(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,13 +9,10 @@ import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...modules.diffusionmodules.sampling_utils import (
|
||||
get_ancestral_step,
|
||||
linear_multistep_coeff,
|
||||
to_d,
|
||||
to_neg_log_sigma,
|
||||
to_sigma,
|
||||
)
|
||||
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
|
||||
linear_multistep_coeff,
|
||||
to_d, to_neg_log_sigma,
|
||||
to_sigma)
|
||||
from ...util import append_dims, default, instantiate_from_config
|
||||
|
||||
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
|
||||
@@ -4,11 +4,6 @@ from scipy import integrate
|
||||
from ...util import append_dims
|
||||
|
||||
|
||||
class NoDynamicThresholding:
|
||||
def __call__(self, uncond, cond, scale):
|
||||
return uncond + scale * (cond - uncond)
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f"Order {order} too high for step {i}")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
adopted from
|
||||
partially adopted from
|
||||
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
and
|
||||
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
@@ -10,10 +10,11 @@ thanks!
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
@@ -306,3 +307,63 @@ def avg_pool_nd(dims, *args, **kwargs):
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
rearrange_pattern: str = "b t -> (b t) 1 1",
|
||||
):
|
||||
super().__init__()
|
||||
self.merge_strategy = merge_strategy
|
||||
self.rearrange_pattern = rearrange_pattern
|
||||
|
||||
assert (
|
||||
merge_strategy in self.strategies
|
||||
), f"merge_strategy needs to be in {self.strategies}"
|
||||
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif (
|
||||
self.merge_strategy == "learned"
|
||||
or self.merge_strategy == "learned_with_images"
|
||||
):
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
||||
if self.merge_strategy == "fixed":
|
||||
alpha = self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return alpha
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_spatial: torch.Tensor,
|
||||
x_temporal: torch.Tensor,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator)
|
||||
x = (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
)
|
||||
return x
|
||||
|
||||
493
sgm/modules/diffusionmodules/video_model.py
Normal file
493
sgm/modules/diffusionmodules/video_model.py
Normal file
@@ -0,0 +1,493 @@
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from ...modules.diffusionmodules.openaimodel import *
|
||||
from ...modules.video_attention import SpatialVideoTransformer
|
||||
from ...util import default
|
||||
from .util import AlphaBlender
|
||||
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
emb_channels: int,
|
||||
dropout: float,
|
||||
video_kernel_size: Union[int, List[int]] = 3,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
use_scale_shift_norm: bool = False,
|
||||
dims: int = 2,
|
||||
use_checkpoint: bool = False,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=out_channels,
|
||||
use_conv=use_conv,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
up=up,
|
||||
down=down,
|
||||
)
|
||||
|
||||
self.time_stack = ResBlock(
|
||||
default(out_channels, channels),
|
||||
emb_channels,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
out_channels=default(out_channels, channels),
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=use_checkpoint,
|
||||
exchange_temb_dims=True,
|
||||
)
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
num_video_frames: int,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
) -> th.Tensor:
|
||||
x = super().forward(x, emb)
|
||||
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
|
||||
x = self.time_stack(
|
||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||
)
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class VideoUNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
num_res_blocks: int,
|
||||
attention_resolutions: int,
|
||||
dropout: float = 0.0,
|
||||
channel_mult: List[int] = (1, 2, 4, 8),
|
||||
conv_resample: bool = True,
|
||||
dims: int = 2,
|
||||
num_classes: Optional[int] = None,
|
||||
use_checkpoint: bool = False,
|
||||
num_heads: int = -1,
|
||||
num_head_channels: int = -1,
|
||||
num_heads_upsample: int = -1,
|
||||
use_scale_shift_norm: bool = False,
|
||||
resblock_updown: bool = False,
|
||||
transformer_depth: Union[List[int], int] = 1,
|
||||
transformer_depth_middle: Optional[int] = None,
|
||||
context_dim: Optional[int] = None,
|
||||
time_downup: bool = False,
|
||||
time_context_dim: Optional[int] = None,
|
||||
extra_ff_mix_layer: bool = False,
|
||||
use_spatial_context: bool = False,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
spatial_transformer_attn_type: str = "softmax",
|
||||
video_kernel_size: Union[int, List[int]] = 3,
|
||||
use_linear_in_transformer: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
disable_temporal_crossattention: bool = False,
|
||||
max_ddpm_temb_period: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
assert context_dim is not None
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
transformer_depth_middle = default(
|
||||
transformer_depth_middle, transformer_depth[-1]
|
||||
)
|
||||
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "timestep":
|
||||
self.label_emb = nn.Sequential(
|
||||
Timestep(model_channels),
|
||||
nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
),
|
||||
)
|
||||
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
linear(adm_in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
def get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=1,
|
||||
context_dim=None,
|
||||
use_checkpoint=False,
|
||||
disabled_sa=False,
|
||||
):
|
||||
return SpatialVideoTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=merge_strategy,
|
||||
merge_factor=merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
attn_mode=spatial_transformer_attn_type,
|
||||
disable_self_attn=disabled_sa,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
merge_factor,
|
||||
merge_strategy,
|
||||
video_kernel_size,
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_ch,
|
||||
dims,
|
||||
use_checkpoint,
|
||||
use_scale_shift_norm,
|
||||
down=False,
|
||||
up=False,
|
||||
):
|
||||
return VideoResBlock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
)
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
|
||||
layers.append(
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth[level],
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disabled_sa=False,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
ds *= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch,
|
||||
conv_resample,
|
||||
dims=dims,
|
||||
out_channels=out_ch,
|
||||
third_down=time_downup,
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
out_ch=None,
|
||||
dropout=dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth_middle,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
),
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
out_ch=None,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
|
||||
layers.append(
|
||||
get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth[level],
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=use_checkpoint,
|
||||
disabled_sa=False,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
out_ch = ch
|
||||
ds //= 2
|
||||
layers.append(
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_ch=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(
|
||||
ch,
|
||||
conv_resample,
|
||||
dims=dims,
|
||||
out_channels=out_ch,
|
||||
third_up=time_downup,
|
||||
)
|
||||
)
|
||||
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
timesteps: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
y: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
):
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
for module in self.input_blocks:
|
||||
h = module(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
image_only_indicator=image_only_indicator,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
)
|
||||
hs.append(h)
|
||||
h = self.middle_block(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
image_only_indicator=image_only_indicator,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
image_only_indicator=image_only_indicator,
|
||||
time_context=time_context,
|
||||
num_video_frames=num_video_frames,
|
||||
)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@@ -10,27 +11,17 @@ import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import ListConfig
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import (
|
||||
ByT5Tokenizer,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
|
||||
T5EncoderModel, T5Tokenizer)
|
||||
|
||||
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
|
||||
from ...modules.diffusionmodules.model import Encoder
|
||||
from ...modules.diffusionmodules.openaimodel import Timestep
|
||||
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ...modules.diffusionmodules.util import (extract_into_tensor,
|
||||
make_beta_schedule)
|
||||
from ...modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from ...util import (
|
||||
autocast,
|
||||
count_params,
|
||||
default,
|
||||
disabled_train,
|
||||
expand_dims_like,
|
||||
instantiate_from_config,
|
||||
)
|
||||
from ...util import (append_dims, autocast, count_params, default,
|
||||
disabled_train, expand_dims_like, instantiate_from_config)
|
||||
|
||||
|
||||
class AbstractEmbModel(nn.Module):
|
||||
@@ -173,7 +164,11 @@ class GeneralConditioner(nn.Module):
|
||||
return output
|
||||
|
||||
def get_unconditional_conditioning(
|
||||
self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
|
||||
self,
|
||||
batch_c: Dict,
|
||||
batch_uc: Optional[Dict] = None,
|
||||
force_uc_zero_embeddings: Optional[List[str]] = None,
|
||||
force_cond_zero_embeddings: Optional[List[str]] = None,
|
||||
):
|
||||
if force_uc_zero_embeddings is None:
|
||||
force_uc_zero_embeddings = []
|
||||
@@ -181,7 +176,7 @@ class GeneralConditioner(nn.Module):
|
||||
for embedder in self.embedders:
|
||||
ucg_rates.append(embedder.ucg_rate)
|
||||
embedder.ucg_rate = 0.0
|
||||
c = self(batch_c)
|
||||
c = self(batch_c, force_cond_zero_embeddings)
|
||||
uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
|
||||
|
||||
for embedder, rate in zip(self.embedders, ucg_rates):
|
||||
@@ -201,12 +196,6 @@ class InceptionV3(nn.Module):
|
||||
self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
|
||||
|
||||
def forward(self, inp):
|
||||
# inp = kornia.geometry.resize(inp, (299, 299),
|
||||
# interpolation='bicubic',
|
||||
# align_corners=False,
|
||||
# antialias=True)
|
||||
# inp = inp.clamp(min=-1, max=1)
|
||||
|
||||
outp = self.model(inp)
|
||||
|
||||
if len(outp) == 1:
|
||||
@@ -277,7 +266,6 @@ class FrozenT5Embedder(AbstractEmbModel):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# @autocast
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
@@ -597,11 +585,12 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
|
||||
repeat_to_max_len=False,
|
||||
num_image_crops=0,
|
||||
output_tokens=False,
|
||||
init_device=None,
|
||||
):
|
||||
super().__init__()
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch,
|
||||
device=torch.device("cpu"),
|
||||
device=torch.device(default(init_device, "cpu")),
|
||||
pretrained=version,
|
||||
)
|
||||
del model.transformer
|
||||
@@ -914,7 +903,6 @@ class LowScaleEncoder(nn.Module):
|
||||
z = self.q_sample(z, noise_level)
|
||||
if self.out_size is not None:
|
||||
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
|
||||
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
@@ -958,3 +946,101 @@ class GaussianEncoder(Encoder, AbstractEmbModel):
|
||||
if self.flatten_output:
|
||||
z = rearrange(z, "b c h w -> b (h w ) c")
|
||||
return log, z
|
||||
|
||||
|
||||
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
|
||||
def __init__(
|
||||
self,
|
||||
n_cond_frames: int,
|
||||
n_copies: int,
|
||||
encoder_config: dict,
|
||||
sigma_sampler_config: Optional[dict] = None,
|
||||
sigma_cond_config: Optional[dict] = None,
|
||||
is_ae: bool = False,
|
||||
scale_factor: float = 1.0,
|
||||
disable_encoder_autocast: bool = False,
|
||||
en_and_decode_n_samples_a_time: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_cond_frames = n_cond_frames
|
||||
self.n_copies = n_copies
|
||||
self.encoder = instantiate_from_config(encoder_config)
|
||||
self.sigma_sampler = (
|
||||
instantiate_from_config(sigma_sampler_config)
|
||||
if sigma_sampler_config is not None
|
||||
else None
|
||||
)
|
||||
self.sigma_cond = (
|
||||
instantiate_from_config(sigma_cond_config)
|
||||
if sigma_cond_config is not None
|
||||
else None
|
||||
)
|
||||
self.is_ae = is_ae
|
||||
self.scale_factor = scale_factor
|
||||
self.disable_encoder_autocast = disable_encoder_autocast
|
||||
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
||||
|
||||
def forward(
|
||||
self, vid: torch.Tensor
|
||||
) -> Union[
|
||||
torch.Tensor,
|
||||
Tuple[torch.Tensor, torch.Tensor],
|
||||
Tuple[torch.Tensor, dict],
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
|
||||
]:
|
||||
if self.sigma_sampler is not None:
|
||||
b = vid.shape[0] // self.n_cond_frames
|
||||
sigmas = self.sigma_sampler(b).to(vid.device)
|
||||
if self.sigma_cond is not None:
|
||||
sigma_cond = self.sigma_cond(sigmas)
|
||||
sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
|
||||
sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
|
||||
noise = torch.randn_like(vid)
|
||||
vid = vid + noise * append_dims(sigmas, vid.ndim)
|
||||
|
||||
with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
|
||||
n_samples = (
|
||||
self.en_and_decode_n_samples_a_time
|
||||
if self.en_and_decode_n_samples_a_time is not None
|
||||
else vid.shape[0]
|
||||
)
|
||||
n_rounds = math.ceil(vid.shape[0] / n_samples)
|
||||
all_out = []
|
||||
for n in range(n_rounds):
|
||||
if self.is_ae:
|
||||
out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
|
||||
else:
|
||||
out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
|
||||
all_out.append(out)
|
||||
|
||||
vid = torch.cat(all_out, dim=0)
|
||||
vid *= self.scale_factor
|
||||
|
||||
vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
|
||||
vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
|
||||
|
||||
return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
|
||||
|
||||
return return_val
|
||||
|
||||
|
||||
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
|
||||
def __init__(
|
||||
self,
|
||||
open_clip_embedding_config: Dict,
|
||||
n_cond_frames: int,
|
||||
n_copies: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_cond_frames = n_cond_frames
|
||||
self.n_copies = n_copies
|
||||
self.open_clip = instantiate_from_config(open_clip_embedding_config)
|
||||
|
||||
def forward(self, vid):
|
||||
vid = self.open_clip(vid)
|
||||
vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
|
||||
vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
|
||||
|
||||
return vid
|
||||
|
||||
301
sgm/modules/video_attention.py
Normal file
301
sgm/modules/video_attention.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import torch
|
||||
|
||||
from ..modules.attention import *
|
||||
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
|
||||
|
||||
|
||||
class TimeMixSequential(nn.Sequential):
|
||||
def forward(self, x, context=None, timesteps=None):
|
||||
for layer in self:
|
||||
x = layer(x, context, timesteps)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VideoTransformerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention,
|
||||
"softmax-xformers": MemoryEfficientCrossAttention,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
timesteps=None,
|
||||
ff_in=False,
|
||||
inner_dim=None,
|
||||
attn_mode="softmax",
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
switch_temporal_ca_to_sa=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
if inner_dim is None:
|
||||
inner_dim = dim
|
||||
|
||||
assert int(n_heads * d_head) == inner_dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
self.ff_in = FeedForward(
|
||||
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
|
||||
)
|
||||
|
||||
self.timesteps = timesteps
|
||||
self.disable_self_attn = disable_self_attn
|
||||
if self.disable_self_attn:
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=inner_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=context_dim,
|
||||
dropout=dropout,
|
||||
) # is a cross-attention
|
||||
else:
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
if switch_temporal_ca_to_sa:
|
||||
raise ValueError
|
||||
else:
|
||||
self.attn2 = None
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(inner_dim)
|
||||
if switch_temporal_ca_to_sa:
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
else:
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=inner_dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim)
|
||||
self.norm3 = nn.LayerNorm(inner_dim)
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
|
||||
) -> torch.Tensor:
|
||||
if self.checkpoint:
|
||||
return checkpoint(self._forward, x, context, timesteps)
|
||||
else:
|
||||
return self._forward(x, context, timesteps=timesteps)
|
||||
|
||||
def _forward(self, x, context=None, timesteps=None):
|
||||
assert self.timesteps or timesteps
|
||||
assert not (self.timesteps and timesteps) or self.timesteps == timesteps
|
||||
timesteps = self.timesteps or timesteps
|
||||
B, S, C = x.shape
|
||||
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
x = self.ff_in(self.norm_in(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
if self.disable_self_attn:
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
else:
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
|
||||
if self.attn2 is not None:
|
||||
if self.switch_temporal_ca_to_sa:
|
||||
x = self.attn2(self.norm2(x)) + x
|
||||
else:
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
x = rearrange(
|
||||
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.ff.net[-1].weight
|
||||
|
||||
|
||||
class SpatialVideoTransformer(SpatialTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
use_linear=False,
|
||||
context_dim=None,
|
||||
use_spatial_context=False,
|
||||
timesteps=None,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
time_context_dim=None,
|
||||
ff_in=False,
|
||||
checkpoint=False,
|
||||
time_depth=1,
|
||||
attn_mode="softmax",
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
max_time_embed_period: int = 10000,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=depth,
|
||||
dropout=dropout,
|
||||
attn_type=attn_mode,
|
||||
use_checkpoint=checkpoint,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
self.depth = depth
|
||||
self.max_time_embed_period = max_time_embed_period
|
||||
|
||||
time_mix_d_head = d_head
|
||||
n_time_mix_heads = n_heads
|
||||
|
||||
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||
|
||||
inner_dim = n_heads * d_head
|
||||
if use_spatial_context:
|
||||
time_context_dim = context_dim
|
||||
|
||||
self.time_stack = nn.ModuleList(
|
||||
[
|
||||
VideoTransformerBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=time_context_dim,
|
||||
timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
attn_mode=attn_mode,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(self.time_stack) == len(self.transformer_blocks)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.time_pos_embed = nn.Sequential(
|
||||
linear(self.in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
time_context: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
spatial_context = context
|
||||
|
||||
if self.use_spatial_context:
|
||||
assert (
|
||||
context.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
|
||||
time_context = context
|
||||
time_context_first_timestep = time_context[::timesteps]
|
||||
time_context = repeat(
|
||||
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
|
||||
)
|
||||
elif time_context is not None and not self.use_spatial_context:
|
||||
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
|
||||
if time_context.ndim == 2:
|
||||
time_context = rearrange(time_context, "b c -> b 1 c")
|
||||
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(
|
||||
num_frames,
|
||||
self.in_channels,
|
||||
repeat_only=False,
|
||||
max_period=self.max_time_embed_period,
|
||||
)
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
for it_, (block, mix_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_stack)
|
||||
):
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
)
|
||||
|
||||
x_mix = x
|
||||
x_mix = x_mix + emb
|
||||
|
||||
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
|
||||
x = self.time_mixer(
|
||||
x_spatial=x,
|
||||
x_temporal=x_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
27
sgm/util.py
27
sgm/util.py
@@ -246,3 +246,30 @@ def get_configs_path() -> str:
|
||||
if os.path.isdir(candidate):
|
||||
return candidate
|
||||
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
|
||||
|
||||
|
||||
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
|
||||
"""
|
||||
Will return the result of a recursive get attribute call.
|
||||
E.g.:
|
||||
a.b.c
|
||||
= getattr(getattr(a, "b"), "c")
|
||||
= get_nested_attribute(a, "b.c")
|
||||
If any part of the attribute call is an integer x with current obj a, will
|
||||
try to call a[x] instead of a.x first.
|
||||
"""
|
||||
attributes = attribute_path.split(".")
|
||||
if depth is not None and depth > 0:
|
||||
attributes = attributes[:depth]
|
||||
assert len(attributes) > 0, "At least one attribute should be selected"
|
||||
current_attribute = obj
|
||||
current_key = None
|
||||
for level, attribute in enumerate(attributes):
|
||||
current_key = ".".join(attributes[: level + 1])
|
||||
try:
|
||||
id_ = int(attribute)
|
||||
current_attribute = current_attribute[id_]
|
||||
except ValueError:
|
||||
current_attribute = getattr(current_attribute, attribute)
|
||||
|
||||
return (current_attribute, current_key) if return_key else current_attribute
|
||||
|
||||
Reference in New Issue
Block a user