soon is now

Andreas Blattmann
2023-06-22 09:53:12 -07:00
commit 081e0d4629
63 changed files with 10916 additions and 0 deletions

7
.gitignore vendored Normal file

@@ -0,0 +1,7 @@
.pt2
.pt2_2
.pt13
*.egg-info
build
/outputs
/checkpoints

75
LICENSE Normal file

@@ -0,0 +1,75 @@
SDXL 0.9 RESEARCH LICENSE AGREEMENT
Copyright (c) Stability AI Ltd.
This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and Stability AI Ltd. (“Stability AI” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by Stability AI under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software (“Documentation”).
By clicking “I Accept” below or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to Stability AI that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
1. LICENSE GRANT
a. Subject to your compliance with the Documentation and Sections 2, 3, and 5, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI's copyright interests to reproduce, distribute, and create derivative works of the Software solely for your non-commercial research purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Stability AI's prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.
b. You may make a reasonable number of copies of the Documentation solely for use in connection with the license to the Software granted above.
c. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Stability AI and its licensors reserve all rights not expressly granted by this License.
2. RESTRICTIONS
You will not, and will not permit, assist or cause any third party to:
a. use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes or in the service of nuclear technology, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing;
b. alter or remove copyright and other proprietary notices which appear on or in the Software Products;
c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Stability AI in connection with the Software, or to circumvent or remove any usage restrictions, or to enable functionality disabled by Stability AI; or
d. offer or impose any terms on the Software Products that alter, restrict, or are inconsistent with the terms of this License.
e. 1) violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”); 2) directly or indirectly export, re-export, provide, or otherwise transfer Software Products: (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download Software Products if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods.
3. ATTRIBUTION
Together with any copies of the Software Products (as well as derivative works thereof or works incorporating the Software Products) that you distribute, you must provide (i) a copy of this License, and (ii) the following attribution notice: “SDXL 0.9 is licensed under the SDXL Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
4. DISCLAIMERS
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. STABILITY AI EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. STABILITY AI MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
5. LIMITATION OF LIABILITY
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL STABILITY AI BE LIABLE TO YOU (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF STABILITY AI HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL'S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
6. INDEMNIFICATION
You will indemnify, defend and hold harmless Stability AI and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Stability AI Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys' fees) incurred by any Stability AI Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Stability AI Parties of any such Claims, and cooperate with Stability AI Parties in defending such Claims. You will also grant the Stability AI Parties sole control of the defense or settlement, at Stability AI's sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Stability AI or the other Stability AI Parties.
7. TERMINATION; SURVIVAL
a. This License will automatically terminate upon any breach by you of the terms of this License.
b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
c. The following sections survive termination of this License: 2 (Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability), 6 (Indemnification), 7 (Termination; Survival), 8 (Third Party Materials), 9 (Trademarks), 10 (Applicable Law; Dispute Resolution), and 11 (Miscellaneous).
8. THIRD PARTY MATERIALS
The Software Products may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Stability AI does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
9. TRADEMARKS
Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with Stability AI without the prior written permission of Stability AI, except to the extent necessary to make the reference required by the “ATTRIBUTION” section of this Agreement.
10. APPLICABLE LAW; DISPUTE RESOLUTION
This License will be governed and construed under the laws of the State of California without regard to conflicts of law provisions. Any suit or proceeding arising out of or relating to this License will be brought in the federal or state courts, as applicable, in San Mateo County, California, and each party irrevocably submits to the jurisdiction and venue of such courts.
11. MISCELLANEOUS
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Stability AI to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Stability AI regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Stability AI regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Stability AI.

187
README.md Normal file

@@ -0,0 +1,187 @@
# Generative Models by Stability AI
![sample1](assets/000.jpg)
## News
**June 22, 2023**
- We are releasing two new diffusion models:
- `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
- `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
**We plan to do a full release soon (July).**
## The codebase
### General Philosophy
Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
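For illustration, a minimal sketch of the pattern (a simplified stand-in, not the exact helper from `sgm.util`): every config node names a `target` class by its import path plus optional `params`, and the helper imports and constructs it.
```python
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    # Resolve the dotted import path in `target` and construct it with `params`.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", dict()))


# Example: the `lossconfig: {target: torch.nn.Identity}` blocks used in the
# autoencoder configs simply become a torch.nn.Identity() instance.
loss = instantiate_from_config(OmegaConf.create({"target": "torch.nn.Identity"}))
```
Submodules built this way can themselves contain further `*_config` blocks, which is how the nested configs below compose a full model.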
### Changelog from the old `ldm` codebase
For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
- We adopt the [“denoiser framework”](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable change is probably the option to train continuous-time models):
* Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
* The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`); a schematic sketch of the resulting denoiser follows this list.
- Autoencoding models have also been cleaned up.
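To make the denoiser framework concrete, here is a schematic sketch (not the repo's exact code) of how a denoiser wraps the raw network with noise-level-dependent preconditioning, following the EDM paper linked above; the `c_*` coefficients are what a scaling config such as `EDMScaling` provides, and broadcasting of `sigma` over spatial dimensions is omitted for brevity.
```python
import torch


def edm_scalings(sigma: torch.Tensor, sigma_data: float = 0.5):
    # Preconditioning coefficients from Karras et al. 2022 ("EDM").
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1.0 / (sigma**2 + sigma_data**2).sqrt()
    c_noise = 0.25 * sigma.log()
    return c_skip, c_out, c_in, c_noise


def denoise(network, x_noisy, sigma, cond):
    # D(x, sigma): valid for any continuous sigma; discrete-time models just
    # restrict sigma to the values of a fixed discretization.
    c_skip, c_out, c_in, c_noise = edm_scalings(sigma)
    return c_skip * x_noisy + c_out * network(c_in * x_noisy, c_noise, cond)
```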
## Installation:
<a name="installation"></a>
#### 1. Clone the repo
```shell
git clone git@github.com:Stability-AI/generative-models.git
cd generative-models
```
#### 2. Setting up the virtualenv
This is assuming you have navigated to the `generative-models` root after cloning it.
**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
**PyTorch 1.13**
```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
```
**PyTorch 2.0**
```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
```
## Inference:
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
**Weights for SDXL**:
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply via either of the two links, and once you are granted access, you can use both models.
Please log in to your Hugging Face account with your organization email to request access.
After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using
```
streamlit run scripts/demo/sampling.py --server.port <your_port>
```
### Invisible Watermark Detection
Images generated with our code use the
[invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
library to embed an invisible watermark into the model output. We also provide
a script to easily detect that watermark. Please note that this watermark is
not the same as in previous Stable Diffusion 1.x/2.x versions.
To run the script you need either a working installation as above or an _experimental_
import using only a minimal set of packages:
```bash
python -m venv .detect
source .detect/bin/activate
pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
pip install --no-deps invisible-watermark
```
The script is then usable in the following ways (don't forget to activate your
virtual environment beforehand, e.g. `source .pt1/bin/activate`):
```bash
# test a single file
python scripts/demo/detect.py <your filename here>
# test multiple files at once
python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
# test all files in a specific folder
python scripts/demo/detect.py <your folder name here>/*
```
## Training:
We are providing example training configs in `configs/example_training`. To launch a training, run
```
python main.py --base configs/<config1.yaml> configs/<config2.yaml>
```
where configs are merged from left to right (later configs overwrite the same values).
This can be used to combine model, training and data configs. However, all of them can also be
defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
run
```bash
python main.py --base configs/example_training/toy/mnist_cond.yaml
```
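The merge semantics are those of OmegaConf: configs listed later override the keys of configs listed earlier, so a small override file can patch a base config. A hypothetical illustration (the values are made up):
```python
from omegaconf import OmegaConf

base = OmegaConf.create(
    {"model": {"base_learning_rate": 1.0e-4}, "data": {"params": {"batch_size": 512}}}
)
override = OmegaConf.create({"data": {"params": {"batch_size": 64}}})

merged = OmegaConf.merge(base, override)  # later configs win
assert merged.data.params.batch_size == 64
assert merged.model.base_learning_rate == 1.0e-4
```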
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the dataset used (which is expected to be stored as tar files in the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
### Building New Diffusion Models
#### Conditioner
The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the rate at which the
conditioning is dropped for classifier-free guidance (`ucg_rate`, default `0`), and an input key (`input_key`), for example `txt` for text-conditioning or `cls` for class-conditioning.
When computing conditionings, the embedder will get `batch[input_key]` as input.
We currently support two- to four-dimensional conditionings; conditionings from different embedders are concatenated
appropriately.
Note that the order of the embedders in the `conditioner_config` is important.
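As a hypothetical illustration of this contract (the attribute names mirror the description above; this is not a class from the repo), a minimal class-conditioning embedder could look like:
```python
import torch
import torch.nn as nn


class ToyClassEmbedder(nn.Module):
    """Maps integer class labels to a vector conditioning."""

    def __init__(self, n_classes: int = 10, embed_dim: int = 128):
        super().__init__()
        self.is_trainable = True  # optimized jointly with the diffusion model
        self.ucg_rate = 0.2       # conditioning dropped 20% of the time (for CFG)
        self.input_key = "cls"    # the conditioner feeds it batch["cls"]
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, cls: torch.Tensor) -> torch.Tensor:
        # Returns a (batch, embed_dim) tensor, i.e. a "vector" conditioning.
        return self.embedding(cls)
```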
#### Network
The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
enough as we plan to experiment with transformer-based diffusion backbones.
#### Loss
The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
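For intuition, the `sigma_sampler_config` decides which noise levels the loss sees during training. A schematic of EDM-style log-normal sigma sampling (the constants are the defaults from the EDM paper, not necessarily the repo's):
```python
import torch


def sample_sigmas_edm(n: int, p_mean: float = -1.2, p_std: float = 1.2) -> torch.Tensor:
    # log(sigma) ~ N(p_mean, p_std^2): one noise level per example in the batch.
    return (p_mean + p_std * torch.randn(n)).exp()
```
Discrete-time training would instead draw `sigma` from the fixed set defined by a discretization (cf. `DiscreteSampling` in the configs below).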
#### Sampler config
As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
solver, the number of steps, the type of discretization, as well as guidance wrappers, for example for classifier-free
guidance.
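As a rough sketch of what such a sampler does (schematic Euler integration in the EDM formulation, not the repo's `EulerEDMSampler` itself): given a decreasing schedule of noise levels produced by the discretization and a denoiser, every step moves the sample along the current denoising direction.
```python
import torch


@torch.no_grad()
def euler_sample(denoise, x, sigmas):
    # `sigmas` is a decreasing schedule ending at 0; `denoise(x, sigma)` returns
    # the model's estimate of the clean sample at noise level sigma.
    x = x * sigmas[0]  # start from pure noise scaled to the largest sigma
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma  # dx/dsigma of the probability-flow ODE
        x = x + (sigma_next - sigma) * d     # explicit Euler step towards sigma_next
    return x
```
A guider such as classifier-free guidance simply wraps `denoise`, combining a conditional and an unconditional prediction before each step.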
### Dataset Handling
For large-scale training we recommend using the data pipelines from our [datapipelines](https://github.com/Stability-AI/datapipelines) project. The project is included in the requirements and installed automatically when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values,
e.g.,
```python
example = {"jpg": x, # this is a tensor -1...1 chw
"txt": "a beautiful image"}
```
where we expect images in -1...1, channel-first format.
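A hypothetical minimal map-style dataset returning such a dict (paths and transforms are illustrative only):
```python
from pathlib import Path

import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class ToyImageCaptionDataset(Dataset):
    """Yields {"jpg": CHW tensor in [-1, 1], "txt": caption string}."""

    def __init__(self, root: str, caption: str = "a beautiful image"):
        self.paths = sorted(Path(root).glob("*.jpg"))
        self.caption = caption
        self.to_tensor = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = self.to_tensor(Image.open(self.paths[idx]).convert("RGB"))  # in [0, 1]
        return {"jpg": img * 2.0 - 1.0, "txt": self.caption}  # rescale to [-1, 1]
```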

BIN
assets/000.jpg Normal file

Binary file not shown.

Size: 711 KiB


@@ -0,0 +1,115 @@
model:
base_learning_rate: 4.5e-6
target: sgm.models.autoencoder.AutoencodingEngine
params:
input_key: jpg
monitor: val/rec_loss
loss_config:
target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
params:
perceptual_weight: 0.25
disc_start: 20001
disc_weight: 0.5
learn_logvar: True
regularization_weights:
kl_loss: 1.0
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: none
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4 ]
num_res_blocks: 4
attn_resolutions: [ ]
dropout: 0.0
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: none
double_z: False
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4 ]
num_res_blocks: 4
attn_resolutions: [ ]
dropout: 0.0
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
- "DATA-PATH"
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- "pil"
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg'
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
params:
h_key: height
w_key: width
loader:
batch_size: 8
num_workers: 4
lightning:
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 50000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
devices: 0,
limit_val_batches: 50
benchmark: True
accumulate_grad_batches: 1
val_check_interval: 10000


@@ -0,0 +1,188 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
log_keys:
- cls
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [1.]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 256
attention_resolutions: [1, 2, 4]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1024
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
add_sequence_dim: True # will be used through crossattn then
embed_dim: 1024
n_classes: 1000
# vector cond
- is_trainable: False
ucg_rate: 0.2
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: CKPT_PATH
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 5.0
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
# USER: adapt this path to the root of your custom dataset
- "DATA_PATH"
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- "pil"
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
params:
h_key: height # USER: you might wanna adapt this for your custom dataset
w_key: width # USER: you might wanna adapt this for your custom dataset
loader:
batch_size: 64
num_workers: 6
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 8
n_rows: 2
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 1000


@@ -0,0 +1,99 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 3
out_channels: 3
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 3.0
data:
target: sgm.data.cifar10.CIFAR10Loader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 64
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 64
n_rows: 8
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20


@@ -0,0 +1,80 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 64
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 64
n_rows: 8
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 10


@@ -0,0 +1,99 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: [ ]
num_res_blocks: 4
channel_mult: [ 1, 2, 2 ]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: "cls"
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 3.0
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 16
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 16
n_rows: 4
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20


@@ -0,0 +1,104 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: [ ]
num_res_blocks: 4
channel_mult: [ 1, 2, 2 ]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: "cls"
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 5.0
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 16
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 16
n_rows: 4
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20


@@ -0,0 +1,104 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: "sequential"
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: "cls"
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 3.0
loss_config:
target: sgm.modules.diffusionmodules.StandardDiffusionLoss
params:
type: l1
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 64
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 64
n_rows: 8
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20


@@ -0,0 +1,101 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
use_ema: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
params:
sigma_data: 1.0
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 3.0
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 64
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 64
n_rows: 8
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20


@@ -0,0 +1,185 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
log_keys:
- txt
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ]
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 1, 2, 4 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
legacy_ucg_value: ""
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: CKPT_PATH
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 7.5
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
# USER: adapt this path to the root of your custom dataset
- "DATA_PATH"
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
decoders:
- "pil"
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
# USER: you might wanna use non-default parameters due to your custom dataset
loader:
batch_size: 64
num_workers: 6
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 8
n_rows: 2
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 1000


@@ -0,0 +1,186 @@
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
log_keys:
- txt
scheduler_config:
target: sgm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ]
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 1, 2, 4 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
num_classes: sequential
adm_in_channels: 1792
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 768
spatial_transformer_attn_type: softmax-xformers
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: True
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
always_return_pooled: True
# vector cond
- is_trainable: False
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
ucg_rate: 0.1
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
ckpt_path: CKPT_PATH
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 7.5
data:
target: sgm.data.dataset.StableDataModuleFromConfig
params:
train:
datapipeline:
urls:
# USER: adapt this path to the root of your custom dataset
- "DATA_PATH"
pipeline_config:
shardshuffle: 10000
sample_shuffle: 10000
decoders:
- "pil"
postprocessors:
- target: sdata.mappers.TorchVisionImageTransforms
params:
key: 'jpg' # USER: you might wanna adapt this for your custom dataset
transforms:
- target: torchvision.transforms.Resize
params:
size: 256
interpolation: 3
- target: torchvision.transforms.ToTensor
- target: sdata.mappers.Rescaler
# USER: you might wanna use non-default parameters due to your custom dataset
- target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
# USER: you might wanna use non-default parameters due to your custom dataset
loader:
batch_size: 64
num_workers: 6
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
enable_autocast: False
batch_frequency: 1000
max_images: 8
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 8
n_rows: 2
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 1000


@@ -0,0 +1,66 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: penultimate
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity


@@ -0,0 +1,66 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: true
layer: penultimate
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity


@@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity


@@ -0,0 +1,91 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

BIN
data/DejaVuSans.ttf Normal file

Binary file not shown.

947
main.py Normal file

@@ -0,0 +1,947 @@
import argparse
import datetime
import glob
import inspect
import os
import sys
from inspect import Parameter
from typing import Union
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only
from sgm.util import (
exists,
instantiate_from_config,
isheatmap,
)
MULTINODE_HACKS = True
def default_trainer_args():
argspec = dict(inspect.signature(Trainer.__init__).parameters)
argspec.pop("self")
default_args = {
param: argspec[param].default
for param in argspec
if argspec[param].default != Parameter.empty  # keep only params that actually have a default
}
return default_args
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"-n",
"--name",
type=str,
const=True,
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
"--no_date",
type=str2bool,
nargs="?",
const=True,
default=False,
help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
)
parser.add_argument(
"-r",
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-t",
"--train",
type=str2bool,
const=True,
default=True,
nargs="?",
help="train",
)
parser.add_argument(
"--no-test",
type=str2bool,
const=True,
default=False,
nargs="?",
help="disable test",
)
parser.add_argument(
"-p", "--project", help="name of new or path to existing project"
)
parser.add_argument(
"-d",
"--debug",
type=str2bool,
nargs="?",
const=True,
default=False,
help="enable post-mortem debugging",
)
parser.add_argument(
"-s",
"--seed",
type=int,
default=23,
help="seed for seed_everything",
)
parser.add_argument(
"-f",
"--postfix",
type=str,
default="",
help="post-postfix for default name",
)
parser.add_argument(
"--projectname",
type=str,
default="stablediffusion",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
default="logs",
help="directory for logging dat shit",
)
parser.add_argument(
"--scale_lr",
type=str2bool,
nargs="?",
const=True,
default=False,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--legacy_naming",
type=str2bool,
nargs="?",
const=True,
default=False,
help="name run based on config file name if true, else by whole path",
)
parser.add_argument(
"--enable_tf32",
type=str2bool,
nargs="?",
const=True,
default=False,
help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
)
parser.add_argument(
"--startup",
type=str,
default=None,
help="Startuptime from distributed script",
)
parser.add_argument(
"--wandb",
type=str2bool,
nargs="?",
const=True,
default=False, # TODO: later default to True
help="log to wandb",
)
parser.add_argument(
"--no_base_name",
type=str2bool,
nargs="?",
const=True,
default=False, # TODO: later default to True
help="log to wandb",
)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="single checkpoint file to resume from",
)
default_args = default_trainer_args()
for key in default_args:
parser.add_argument("--" + key, default=default_args[key])
return parser
def get_checkpoint_name(logdir):
ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
ckpt = natsorted(glob.glob(ckpt))
print('available "last" checkpoints:')
print(ckpt)
if len(ckpt) > 1:
print("got most recent checkpoint")
ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
print(f"Most recent ckpt is {ckpt}")
with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
f.write(ckpt + "\n")
try:
version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
except Exception as e:
print("version confusion but not bad")
print(e)
version = 1
# version = last_version + 1
else:
# in this case, we only have one "last.ckpt"
ckpt = ckpt[0]
version = 1
melk_ckpt_name = f"last-v{version}.ckpt"
print(f"Current melk ckpt name: {melk_ckpt_name}")
return ckpt, melk_ckpt_name
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
config,
lightning_config,
debug,
ckpt_name=None,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
self.config = config
self.lightning_config = lightning_config
self.debug = debug
self.ckpt_name = ckpt_name
def on_exception(self, trainer: pl.Trainer, pl_module, exception):
if not self.debug and trainer.global_rank == 0:
print("Summoning checkpoint.")
if self.ckpt_name is None:
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
else:
ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
if "callbacks" in self.lightning_config:
if (
"metrics_over_trainsteps_checkpoint"
in self.lightning_config["callbacks"]
):
os.makedirs(
os.path.join(self.ckptdir, "trainstep_checkpoints"),
exist_ok=True,
)
print("Project config")
print(OmegaConf.to_yaml(self.config))
if MULTINODE_HACKS:
import time
time.sleep(5)
OmegaConf.save(
self.config,
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
)
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
OmegaConf.save(
OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
)
else:
# ModelCheckpoint callback created log directory --- remove it
if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_before_first_step=False,
enable_autocast=True,
):
super().__init__()
self.enable_autocast = enable_autocast
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
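# e.g. batch_frequency=1000 yields log_steps=[1, 2, 4, ..., 512], so images are
# logged at exponentially spaced steps early in training and every
# `batch_frequency` steps afterwards (see check_frequency below)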
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_before_first_step = log_before_first_step
@rank_zero_only
def log_local(
self,
save_dir,
split,
images,
global_step,
current_epoch,
batch_idx,
pl_module: Union[None, pl.LightningModule] = None,
):
root = os.path.join(save_dir, "images", split)
for k in images:
if isheatmap(images[k]):
fig, ax = plt.subplots()
ax = ax.matshow(
images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
)
plt.colorbar(ax)
plt.axis("off")
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
k, global_step, current_epoch, batch_idx
)
os.makedirs(root, exist_ok=True)
path = os.path.join(root, filename)
plt.savefig(path)
plt.close()
# TODO: support wandb
else:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
k, global_step, current_epoch, batch_idx
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
img = Image.fromarray(grid)
img.save(path)
if exists(pl_module):
assert isinstance(
pl_module.logger, WandbLogger
), "logger_log_image only supports WandbLogger currently"
pl_module.logger.log_image(
key=f"{split}/{k}",
images=[
img,
],
step=pl_module.global_step,
)
@rank_zero_only
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (
self.check_frequency(check_idx)
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
and callable(pl_module.log_images)
and
# batch_idx > 5 and
self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
gpu_autocast_kwargs = {
"enabled": self.enable_autocast, # torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
if not isheatmap(images[k]):
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().float().cpu()
if self.clamp and not isheatmap(images[k]):
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
pl_module=pl_module
if isinstance(pl_module.logger, WandbLogger)
else None,
)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step
):
try:
self.log_steps.pop(0)
except IndexError as e:
print(e)
pass
return True
return False
@rank_zero_only
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
@rank_zero_only
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if self.log_before_first_step and pl_module.global_step == 0:
print(f"{self.__class__.__name__}: logging before training")
self.log_img(pl_module, batch, batch_idx, split="train")
@rank_zero_only
# def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, **kwargs
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
@rank_zero_only
def init_wandb(save_dir, opt, config, group_name, name_str):
print(f"setting WANDB_DIR to {save_dir}")
os.makedirs(save_dir, exist_ok=True)
os.environ["WANDB_DIR"] = save_dir
if opt.debug:
wandb.init(project=opt.projectname, mode="offline", group=group_name)
else:
wandb.init(
project=opt.projectname,
config=config,
settings=wandb.Settings(code_dir="./sgm"),
group=group_name,
name=name_str,
)
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
# `nested.key=value` arguments are interpreted as config parameters.
# configs are merged from left-to-right followed by command line parameters.
# model:
# base_learning_rate: float
# target: path to lightning module
# params:
# key: value
# data:
# target: main.DataModuleFromConfig
# params:
# batch_size: int
# wrap: bool
# train:
# target: path to train dataset
# params:
# key: value
# validation:
# target: path to validation dataset
# params:
# key: value
# test:
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
# logger to instantiate
# modelcheckpoint:
# modelcheckpoint to instantiate
# callbacks:
# callback1:
# target: importpath
# params:
# key: value
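# Example invocation (paths and values are placeholders, adjust to your setup):
# python main.py -b configs/<config1>.yaml configs/<config2>.yaml -n my_run \
# --devices 0, --wandb True data.params.batch_size=8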
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
# (in particular `main.DataModuleFromConfig`)
sys.path.append(os.getcwd())
parser = get_parser()
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
melk_ckpt_name = None
name = None
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
ckpt = opt.resume
_, melk_ckpt_name = get_checkpoint_name(logdir)
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt, melk_ckpt_name = get_checkpoint_name(logdir)
print("#" * 100)
print(f'Resuming from checkpoint "{ckpt}"')
print("#" * 100)
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
nowname = _tmp[-1]
else:
if opt.name:
name = "_" + opt.name
elif opt.base:
if opt.no_base_name:
name = ""
else:
if opt.legacy_naming:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
else:
assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
opt.base[0]
)[0]
cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
os.path.split(opt.base[0])[0].split(os.sep).index("configs")
+ 1 :
] # cut away the first one (we assert all configs are in "configs")
cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
name = "_" + cfg_name
else:
name = ""
if not opt.no_date:
nowname = now + name + opt.postfix
else:
nowname = name + opt.postfix
if nowname.startswith("_"):
nowname = nowname[1:]
logdir = os.path.join(opt.logdir, nowname)
print(f"LOGDIR: {logdir}")
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed, workers=True)
# move before model init, in case a torch.compile(...) is called somewhere
if opt.enable_tf32:
# pt_version = version.parse(torch.__version__)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print(f"Enabling TF32 for PyTorch {torch.__version__}")
else:
print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
print(
f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
)
print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to gpu
trainer_config["accelerator"] = "gpu"
#
standard_args = default_trainer_args()
for k in standard_args:
if getattr(opt, k) != standard_args[k]:
trainer_config[k] = getattr(opt, k)
ckpt_resume_path = opt.resume_from_checkpoint
if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
del trainer_config["accelerator"]
cpu = True
else:
gpuinfo = trainer_config["devices"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
# default logger configs
default_logger_cfgs = {
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
# "save_dir": logdir,
"offline": opt.debug,
"id": nowname,
"project": opt.projectname,
"log_model": False,
# "dir": logdir,
},
},
"csv": {
"target": "pytorch_lightning.loggers.CSVLogger",
"params": {
"name": "testtube", # hack for sbord fanatics
"save_dir": logdir,
},
},
}
default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
if opt.wandb:
# TODO change once leaving "swiffer" config directory
try:
group_name = nowname.split(now)[-1].split("-")[1]
except Exception:
group_name = nowname
default_logger_cfg["params"]["group"] = group_name
init_wandb(
os.path.join(os.getcwd(), logdir),
opt=opt,
group_name=group_name,
config=config,
name_str=nowname,
)
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
},
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
# https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
# default to ddp if not further specified
default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
if "strategy" in lightning_config:
strategy_cfg = lightning_config.strategy
else:
strategy_cfg = OmegaConf.create()
default_strategy_config["params"] = {
"find_unused_parameters": False,
# "static_graph": True,
# "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
}
strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
print(
f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
)
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
"debug": opt.debug,
"ckpt_name": melk_ckpt_name,
},
},
"image_logger": {
"target": "main.ImageLogger",
"params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
},
"learning_rate_logger": {
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
}
if version.parse(pl.__version__) >= version.parse("1.4.0"):
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
print(
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
)
default_metrics_over_trainsteps_ckpt_dict = {
"metrics_over_trainsteps_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
"save_top_k": -1,
"every_n_train_steps": 10000,
"save_weights_only": True,
},
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
elif "ignore_keys_callback" in callbacks_cfg:
del callbacks_cfg["ignore_keys_callback"]
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
if not "plugins" in trainer_kwargs:
trainer_kwargs["plugins"] = list()
# cmd line trainer args (which are in trainer_opt) have always priority over config-trainer-args (which are in trainer_kwargs)
trainer_opt = vars(trainer_opt)
trainer_kwargs = {
key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
}
trainer = Trainer(**trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###
# data
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
# data.setup()
print("#### Data #####")
try:
for k in data.datasets:
print(
f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
)
except Exception:
print("datasets not yet initialized.")
# configure learning rate
if "batch_size" in config.data.params:
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
else:
bs, base_lr = (
config.data.params.train.loader.batch_size,
config.model.base_learning_rate,
)
if not cpu:
ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
else:
ngpu = 1
if "accumulate_grad_batches" in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
)
)
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
if melk_ckpt_name is None:
ckpt_path = os.path.join(ckptdir, "last.ckpt")
else:
ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# run
if opt.train:
try:
trainer.fit(model, data, ckpt_path=ckpt_resume_path)
except Exception:
if not opt.debug:
melk()
raise
if not opt.no_test and not trainer.interrupted:
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import socket
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
print(
f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
flush=True,
)
raise err
except Exception:
if opt.debug and trainer.global_rank == 0:
try:
import pudb as debugger
except ImportError:
import pdb as debugger
debugger.post_mortem()
raise
finally:
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
if opt.wandb:
wandb.finish()
# if trainer.global_rank == 0:
# print(trainer.profiler.summary())

41
requirements_pt13.txt Normal file
View File

@@ -0,0 +1,41 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
xformers==0.0.16
torchaudio==0.13.1
torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
matplotlib
natsort
tensorboardx==2.5.1
open-clip-torch
chardet
scipy
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
streamlit>=0.73.1
timm
tokenizers==0.12.1
torchdata==0.5.1
transformers==4.19.1
onnx<=1.12.0
triton
wandb
invisible-watermark
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .

41
requirements_pt2.txt Normal file
View File

@@ -0,0 +1,41 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
torch
matplotlib
torchaudio>=2.0.2
torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.0.0
torchdata==0.6.1
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .

157
scripts/demo/detect.py Normal file
View File

@@ -0,0 +1,157 @@
import argparse
import cv2
import numpy as np
try:
from imwatermark import WatermarkDecoder
except ImportError as e:
try:
# Assume some of the other dependencies such as torch are not fulfilled
# import file without loading unnecessary libraries.
import importlib.util
import sys
spec = importlib.util.find_spec("imwatermark.maxDct")
assert spec is not None
maxDct = importlib.util.module_from_spec(spec)
sys.modules["maxDct"] = maxDct
spec.loader.exec_module(maxDct)
class WatermarkDecoder(object):
"""A minimal version of
https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
to only reconstruct bits using dwtDct"""
def __init__(self, wm_type="bytes", length=0):
assert wm_type == "bits", "Only bits defined in minimal import"
self._wmType = wm_type
self._wmLen = length
def reconstruct(self, bits):
if len(bits) != self._wmLen:
raise RuntimeError("bits are not matched with watermark length")
return bits
def decode(self, cv2Image, method="dwtDct", **configs):
(r, c, channels) = cv2Image.shape
if r * c < 256 * 256:
raise RuntimeError("image too small, should be larger than 256x256")
bits = []
assert method == "dwtDct"
embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
bits = embed.decode(cv2Image)
return self.reconstruct(bits)
except Exception:
raise e
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
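# the message above has 48 bits, so WATERMARK_BITS is a list of 48 zeros and ones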
MATCH_VALUES = [
[27, "No watermark detected"],
[33, "Partial watermark match. Cannot determine with certainty."],
[
35,
(
"Likely watermarked. In our test 0.02% of real images were "
'falsely detected as "Likely watermarked"'
),
],
[
49,
(
"Very likely watermarked. In our test no real images were "
'falsely detected as "Very likely watermarked"'
),
],
]
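# The lookup in __main__ below picks the first entry whose threshold is not
# exceeded: <=27 matched bits -> "No watermark detected", 28-33 -> partial
# match, 34-35 -> likely watermarked, 36-48 -> very likely watermarked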
class GetWatermarkMatch:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(self.watermark)
self.decoder = WatermarkDecoder("bits", self.num_bits)
def __call__(self, x: np.ndarray) -> np.ndarray:
"""
Counts the number of bits that match the predefined watermark for one
or multiple images. Images should be in cv2 format, e.g. h x w x c.
Args:
x: ([B], h, w, c) in range [0, 255]
Returns:
number of matched bits ([B],)
"""
squeeze = len(x.shape) == 3
if squeeze:
x = x[None, ...]
x = np.flip(x, axis=-1)
bs = x.shape[0]
detected = np.empty((bs, self.num_bits), dtype=bool)
for k in range(bs):
detected[k] = self.decoder.decode(x[k], "dwtDct")
result = np.sum(detected == self.watermark, axis=-1)
if squeeze:
return result[0]
else:
return result
get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"filename",
nargs="+",
type=str,
help="Image files to check for watermarks",
)
opts = parser.parse_args()
print(
"""
This script tries to detect watermarked images. Please be aware of
the following:
- As the watermark is supposed to be invisible, there is the risk that
watermarked images may not be detected.
- To maximize the chance of detection, make sure that the image has the same
dimensions as when the watermark was applied (most likely 1024x1024
or 512x512).
- Certain image manipulations may drastically decrease the chance that
watermarks can be detected.
- An image may also exhibit the characteristics of the watermark purely
by chance.
- The watermarking script is public, so anybody may watermark arbitrary images
and could therefore claim they were generated.
- All numbers below are based on a test using 10,000 images without any
modifications after applying the watermark.
"""
)
for fn in opts.filename:
image = cv2.imread(fn)
if image is None:
print(f"Couldn't read {fn}. Skipping")
continue
num_bits = get_watermark_match(image)
k = 0
while num_bits > MATCH_VALUES[k][0]:
k += 1
print(
f"{fn}: {MATCH_VALUES[k][1]}",
f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
sep="\n\t",
)

328
scripts/demo/sampling.py Normal file
View File

@@ -0,0 +1,328 @@
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
SAVE_PATH = "outputs/demo/txt2img/"
SD_XL_BASE_RATIOS = {
"0.5": (704, 1408),
"0.52": (704, 1344),
"0.57": (768, 1344),
"0.6": (768, 1280),
"0.68": (832, 1216),
"0.72": (832, 1152),
"0.78": (896, 1152),
"0.82": (896, 1088),
"0.88": (960, 1088),
"0.94": (960, 1024),
"1.0": (1024, 1024),
"1.07": (1024, 960),
"1.13": (1088, 960),
"1.21": (1088, 896),
"1.29": (1152, 896),
"1.38": (1152, 832),
"1.46": (1216, 832),
"1.67": (1280, 768),
"1.75": (1344, 768),
"1.91": (1344, 704),
"2.0": (1408, 704),
"2.09": (1472, 704),
"2.4": (1536, 640),
"2.5": (1600, 640),
"2.89": (1664, 576),
"3.0": (1728, 576),
}
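# keys are aspect ratios (width/height); values are (width, height) pairs that
# are multiples of 64 and contain roughly 1024*1024 pixels each, presumably the
# resolution buckets SD-XL was trained on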
VERSION2SPECS = {
"SD-XL base": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
"is_guided": True,
},
"sd-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
"is_guided": True,
},
"sd-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-Refiner": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
"is_guided": True,
},
}
def load_img(display=True, key=None, device="cuda"):
image = get_interactive_image(key=key)
if image is None:
return None
if display:
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image.to(device)
def run_txt2img(
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
):
if version == "SD-XL base":
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
W, H = SD_XL_BASE_RATIOS[ratio]
else:
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
)
C = version_dict["C"]
F = version_dict["f"]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
prompt=prompt,
negative_prompt=negative_prompt,
)
num_rows, num_cols, sampler = init_sampling(
use_identity_guider=not version_dict["is_guided"]
)
num_samples = num_rows * num_cols
if st.button("Sample"):
st.write(f"**Model I:** {version}")
out = do_sample(
state["model"],
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
filter=filter,
)
return out
def run_img2img(
state, version_dict, is_legacy=False, return_latents=False, filter=None
):
img = load_img()
if img is None:
return None
H, W = img.shape[2], img.shape[3]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
)
num_rows, num_cols, sampler = init_sampling(
img2img_strength=strength,
use_identity_guider=not version_dict["is_guided"],
)
num_samples = num_rows * num_cols
if st.button("Sample"):
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
filter=filter,
)
return out
def apply_refiner(
input,
state,
sampler,
num_samples,
prompt,
negative_prompt,
filter=None,
):
init_dict = {
"orig_width": input.shape[3] * 8,
"orig_height": input.shape[2] * 8,
"target_width": input.shape[3] * 8,
"target_height": input.shape[2] * 8,
}
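# `input` is a latent tensor, so pixel-space sizes for the conditioning are
# recovered by multiplying the spatial dims with the downsampling factor f=8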
value_dict = init_dict
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["crop_coords_top"] = 0
value_dict["crop_coords_left"] = 0
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
st.warning(f"refiner input shape: {input.shape}")
samples = do_img2img(
input,
state["model"],
sampler,
value_dict,
num_samples,
skip_encode=True,
filter=filter,
)
return samples
if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
if version == "SD-XL base":
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
st.write("__________________________")
else:
add_pipeline = False
filter = DeepFloydDataFiltering(verbose=False)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
is_legacy = version_dict["is_legacy"]
prompt = st.text_input(
"prompt",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
)
if is_legacy:
negative_prompt = st.text_input("negative prompt", "")
else:
negative_prompt = "" # which is unused
if add_pipeline:
st.write("__________________________")
version2 = "SDXL-Refiner"
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2)
st.info(state2["msg"])
stage2strength = st.number_input(
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
)
sampler2 = init_sampling(
key=2,
img2img_strength=stage2strength,
use_identity_guider=not version_dict["is_guided"],
get_num_samples=False,
)
st.write("__________________________")
if mode == "txt2img":
out = run_txt2img(
state,
version,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
)
elif mode == "img2img":
out = run_img2img(
state,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
)
else:
raise ValueError(f"unknown mode {mode}")
if isinstance(out, (tuple, list)):
samples, samples_z = out
else:
samples = out
if add_pipeline:
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
state2,
sampler2,
samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=filter,
)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)

668
scripts/demo/streamlit_helpers.py Normal file
View File

@@ -0,0 +1,668 @@
import os
from typing import Union, List
import math
import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# the watermarking library expects input in cv2 format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
@st.cache_resource()
def init_st(version_dict, load_ckpt=True):
state = dict()
if not "model" in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]
config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
state["msg"] = msg
state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
return state
def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt is not None:
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
global_step = pl_sd["global_step"]
st.info(f"loaded ckpt from global step {global_step}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
msg = None
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
else:
msg = None
model.cuda()
model.eval()
return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
value_dict = {}
for key in keys:
if key == "txt":
if prompt is None:
prompt = st.text_input(
"Prompt", "A professional photograph of an astronaut riding a pig"
)
if negative_prompt is None:
negative_prompt = st.text_input("Negative prompt", "")
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
if key == "original_size_as_tuple":
orig_width = st.number_input(
"orig_width",
value=init_dict["orig_width"],
min_value=16,
)
orig_height = st.number_input(
"orig_height",
value=init_dict["orig_height"],
min_value=16,
)
value_dict["orig_width"] = orig_width
value_dict["orig_height"] = orig_height
if key == "crop_coords_top_left":
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
value_dict["crop_coords_top"] = crop_coord_top
value_dict["crop_coords_left"] = crop_coord_left
if key == "aesthetic_score":
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
if key == "target_size_as_tuple":
target_width = st.number_input(
"target_width",
value=init_dict["target_width"],
min_value=16,
)
target_height = st.number_input(
"target_height",
value=init_dict["target_height"],
min_value=16,
)
value_dict["target_width"] = target_width
value_dict["target_height"] = target_height
return value_dict
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watemark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
os.path.join(save_path, f"{base_count:09}.png")
)
base_count += 1
def init_save_locally(_dir, init_value: bool = False):
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
if save_locally:
save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
else:
save_path = None
return save_locally, save_path
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
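# e.g. with strength=0.5 and 50 steps, only the ~25 smallest sigmas survive, so
# img2img starts from a partially noised latent instead of pure noise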
def get_guider(key):
guider = st.sidebar.selectbox(
f"Discretization #{key}",
[
"VanillaCFG",
"IdentityGuider",
],
)
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
)
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
[
"None",
],
)
if thresholder == "None":
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
def init_sampling(
key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
):
if get_num_samples:
num_rows = 1
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
steps = st.sidebar.number_input(
f"steps #{key}", value=50, min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
],
0,
)
discretization = st.sidebar.selectbox(
f"Discretization #{key}",
[
"LegacyDDPMDiscretization",
"EDMDiscretization",
],
)
discretization_config = get_discretization(discretization, key=key)
guider_config = get_guider(key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
)
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
if get_num_samples:
return num_rows, num_cols, sampler
return sampler
def get_discretization(discretization, key=1):
if discretization == "LegacyDDPMDiscretization":
use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
"params": {"legacy_range": not use_new_range},
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
rho = st.number_input(f"rho #{key}", value=3.0)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": sigma_min,
"sigma_max": sigma_max,
"rho": rho,
},
}
return discretization_config
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
if sampler_name == "EulerEDMSampler":
sampler = EulerEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "HeunEDMSampler":
sampler = HeunEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
):
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
if sampler_name == "EulerAncestralSampler":
sampler = EulerAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2SAncestralSampler":
sampler = DPMPP2SAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2MSampler":
sampler = DPMPP2MSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
)
elif sampler_name == "LinearMultistepSampler":
order = st.sidebar.number_input("order", value=4, min_value=1)
sampler = LinearMultistepSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=order,
verbose=True,
)
else:
raise ValueError(f"unknown sampler {sampler_name}!")
return sampler
def get_interactive_image(key=None) -> Image.Image:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def load_img(display=True, key=None):
image = get_interactive_image(key=key)
if image is None:
return None
if display:
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 2.0 - 1.0),
]
)
img = transform(image)[None, ...]
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
return img
def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
return init_image
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
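# e.g. for SD-XL base at 1024x1024 (C=4, F=8) this is (num_samples, 4, 128, 128)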
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
@torch.no_grad()
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: float = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
):
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0]
st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
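# i.e. noised_z = (z + sigma * eps) / sqrt(1 + sigma^2), which corresponds to a
# variance-preserving (DDPM-style) forward process with alpha_bar = 1 / (1 + sigma^2)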
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples

104
scripts/util/detection/nsfw_and_watermark_dectection.py Normal file
View File

@@ -0,0 +1,104 @@
import os
import torch
import numpy as np
import torchvision.transforms as T
from PIL import Image
import clip
RESOURCES_ROOT = "scripts/util/detection/"
def predict_proba(X, weights, biases):
logits = X @ weights.T + biases
proba = np.where(
logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
)
return proba.T
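# numerically stable sigmoid applied to a linear head, i.e.
# sigmoid(X @ weights.T + biases), evaluated with exp(-logits) for non-negative
# logits and exp(logits) / (1 + exp(logits)) for negative ones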
def load_model_weights(path: str):
model_weights = np.load(path)
return model_weights["weights"], model_weights["biases"]
def clip_process_images(images: torch.Tensor) -> torch.Tensor:
min_size = min(images.shape[-2:])
return T.Compose(
[
T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
T.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(images)
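# standard CLIP preprocessing (center crop, 224px bicubic resize, CLIP mean/std
# normalization); images are expected as float tensors in [0, 1]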
class DeepFloydDataFiltering(object):
def __init__(self, verbose: bool = False):
super().__init__()
self.verbose = verbose
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
self.clip_model.eval()
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
)
self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
)
self.w_threshold, self.p_threshold = 0.5, 0.5
@torch.inference_mode()
def __call__(self, images: torch.Tensor) -> torch.Tensor:
imgs = clip_process_images(images)
image_features = self.clip_model.encode_image(imgs.to("cpu"))
image_features = image_features.detach().cpu().numpy().astype(np.float16)
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
query = p_pred > self.p_threshold
if query.sum() > 0:
print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
query = w_pred > self.w_threshold
if query.sum() > 0:
print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
return images
def load_img(path: str) -> torch.Tensor:
image = Image.open(path)
if not image.mode == "RGB":
image = image.convert("RGB")
image_transforms = T.Compose(
[
T.ToTensor(),
]
)
return image_transforms(image)[None, ...]
def test(root):
from einops import rearrange
filter = DeepFloydDataFiltering(verbose=True)
for p in os.listdir(root):
print(f"running on {p}...")
img = load_img(os.path.join(root, p))
filtered_img = filter(img)
filtered_img = rearrange(
255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
).astype(np.uint8)
Image.fromarray(filtered_img).save(
os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
)
if __name__ == "__main__":
import fire
fire.Fire(test)
print("done.")

Binary file not shown.

Binary file not shown.

13
setup.py Normal file
View File

@@ -0,0 +1,13 @@
from setuptools import find_packages, setup
setup(
name="sgm",
version="0.0.1",
packages=find_packages(),
python_requires=">=3.8",
py_modules=["sgm"],
description="Stability Generative Models",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
url="https://github.com/Stability-AI/generative-models",
)

3
sgm/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config

1
sgm/data/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .dataset import StableDataModuleFromConfig

67
sgm/data/cifar10.py Normal file
View File

@@ -0,0 +1,67 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
class CIFAR10DataDictWrapper(Dataset):
def __init__(self, dset):
super().__init__()
self.dset = dset
def __getitem__(self, i):
x, y = self.dset[i]
return {"jpg": x, "cls": y}
def __len__(self):
return len(self.dset)
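# wraps samples into dicts; the keys ("jpg" for the image, "cls" for the label)
# are presumably the input keys referenced by the example training configs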
class CIFAR10Loader(pl.LightningDataModule):
def __init__(self, batch_size, num_workers=0, shuffle=True):
super().__init__()
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
)
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.train_dataset = CIFAR10DataDictWrapper(
torchvision.datasets.CIFAR10(
root=".data/", train=True, download=True, transform=transform
)
)
self.test_dataset = CIFAR10DataDictWrapper(
torchvision.datasets.CIFAR10(
root=".data/", train=False, download=True, transform=transform
)
)
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
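This toy datamodule (like the MNIST one further below) wraps every sample into a {"jpg": image, "cls": label} dict, so the models defined later in this package can pick their inputs by key (their input_key defaults to "jpg"). A quick smoke test with the classes above in scope; CIFAR-10 is downloaded into .data/ on first use:

loader = CIFAR10Loader(batch_size=8, num_workers=0)
batch = next(iter(loader.train_dataloader()))
print(batch["jpg"].shape, batch["cls"].shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
print(batch["jpg"].min() >= -1.0, batch["jpg"].max() <= 1.0)  # images rescaled to [-1, 1]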

80
sgm/data/dataset.py Normal file
View File

@@ -0,0 +1,80 @@
from typing import Optional
import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
print("#" * 100)
print("Datasets not yet available")
print("to enable, we need to add stable-datasets as a submodule")
print("please use ``git submodule update --init --recursive``")
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
print("#" * 100)
exit(1)
class StableDataModuleFromConfig(LightningDataModule):
def __init__(
self,
train: DictConfig,
validation: Optional[DictConfig] = None,
test: Optional[DictConfig] = None,
skip_val_loader: bool = False,
dummy: bool = False,
):
super().__init__()
self.train_config = train
assert (
"datapipeline" in self.train_config and "loader" in self.train_config
), "train config requires the fields `datapipeline` and `loader`"
self.val_config = validation
if not skip_val_loader:
if self.val_config is not None:
assert (
"datapipeline" in self.val_config and "loader" in self.val_config
), "validation config requires the fields `datapipeline` and `loader`"
else:
                print(
                    "Warning: No validation datapipeline defined, using the training datapipeline"
)
self.val_config = train
self.test_config = test
if self.test_config is not None:
assert (
"datapipeline" in self.test_config and "loader" in self.test_config
), "test config requires the fields `datapipeline` and `loader`"
self.dummy = dummy
if self.dummy:
print("#" * 100)
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
print("#" * 100)
def setup(self, stage: str) -> None:
print("Preparing datasets")
if self.dummy:
data_fn = create_dummy_dataset
else:
data_fn = create_dataset
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
if self.val_config:
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
if self.test_config:
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
return loader
def val_dataloader(self) -> wds.DataPipeline:
return create_loader(self.val_datapipeline, **self.val_config.loader)
def test_dataloader(self) -> wds.DataPipeline:
return create_loader(self.test_datapipeline, **self.test_config.loader)
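StableDataModuleFromConfig is entirely config-driven: each split must provide a `datapipeline` block (forwarded to create_dataset, or create_dummy_dataset when dummy=True) and a `loader` block (forwarded to create_loader). A hedged sketch of the expected config shape, assuming the stable-datasets submodule is installed as the import guard above requires; the inner keys are assumptions, since they are defined by stable-datasets rather than by this file:

from omegaconf import OmegaConf

# Only the top-level "datapipeline"/"loader" structure is required by the asserts above;
# the nested keys (urls, batch_size, ...) are illustrative assumptions.
train_cfg = OmegaConf.create(
    {
        "datapipeline": {"urls": ["/path/to/shards/{00000..00099}.tar"]},
        "loader": {"batch_size": 4, "num_workers": 2},
    }
)
datamodule = StableDataModuleFromConfig(train=train_cfg)
datamodule.setup("fit")                 # builds the datapipelines
train_loader = datamodule.train_dataloader()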

85
sgm/data/mnist.py Normal file
View File

@@ -0,0 +1,85 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
class MNISTDataDictWrapper(Dataset):
def __init__(self, dset):
super().__init__()
self.dset = dset
def __getitem__(self, i):
x, y = self.dset[i]
return {"jpg": x, "cls": y}
def __len__(self):
return len(self.dset)
class MNISTLoader(pl.LightningDataModule):
def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
super().__init__()
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
)
self.batch_size = batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
self.shuffle = shuffle
self.train_dataset = MNISTDataDictWrapper(
torchvision.datasets.MNIST(
root=".data/", train=True, download=True, transform=transform
)
)
self.test_dataset = MNISTDataDictWrapper(
torchvision.datasets.MNIST(
root=".data/", train=False, download=True, transform=transform
)
)
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
if __name__ == "__main__":
dset = MNISTDataDictWrapper(
torchvision.datasets.MNIST(
root=".data/",
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
),
)
)
ex = dset[0]

135
sgm/lr_scheduler.py Normal file
View File

@@ -0,0 +1,135 @@
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f
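All of these schedulers return a multiplicative factor rather than setting the learning rate directly, which is why the docstrings recommend a base_lr of 1.0 when lr_min/lr_max (or f_min/f_max) are meant as absolute learning rates; DiffusionEngine below wires them into torch.optim.lr_scheduler.LambdaLR in exactly this way. A small usage sketch (optimizer, step counts and values are illustrative):

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1.0)   # base_lr of 1.0, as the docstring suggests

sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=1e-6, lr_max=1e-4, lr_start=0.0, max_decay_steps=10_000
)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for step in range(200):
    opt.step()
    scheduler.step()   # effective lr = base_lr * sched_fn(step): linear warm-up, then cosine decay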

2
sgm/models/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .autoencoder import AutoencodingEngine
from .diffusion import DiffusionEngine

335
sgm/models/autoencoder.py Normal file
View File

@@ -0,0 +1,335 @@
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig
from packaging import version
from safetensors.torch import load_file as load_safetensors
from ..modules.diffusionmodules.model import Decoder, Encoder
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list, ListConfig] = (),
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def init_from_ckpt(
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self) -> Any:
raise NotImplementedError()
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
**kwargs,
):
super().__init__(*args, **kwargs)
# todo: add options to freeze encoder/decoder
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"}
)
self.lr_g_factor = lr_g_factor
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = (
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.regularization.get_trainable_parameters())
+ list(self.loss.get_trainable_autoencoder_parameters())
)
return params
def get_discriminator_params(self) -> list:
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(self, x: Any, return_reg_log: bool = False) -> Any:
z = self.encoder(x)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: Any) -> torch.Tensor:
x = self.decoder(z)
return x
def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z)
return z, dec, reg_log
def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
def validation_step(self, batch, batch_idx) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
aeloss, log_dict_ae = self.loss(
regularization_log,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
discloss, log_dict_disc = self.loss(
regularization_log,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + postfix,
)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
log_dict_ae.update(log_dict_disc)
self.log_dict(log_dict_ae)
return log_dict_ae
def configure_optimizers(self) -> Any:
ae_params = self.get_autoencoder_params()
disc_params = self.get_discriminator_params()
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
return [opt_ae, opt_disc], []
@torch.no_grad()
def log_images(self, batch: Dict, **kwargs) -> Dict:
log = dict()
x = self.get_input(batch)
_, xrec, _ = self(x)
log["inputs"] = x
log["reconstructions"] = xrec
with self.ema_scope():
_, xrec_ema, _ = self(x)
log["reconstructions_ema"] = xrec_ema
return log
class AutoencoderKL(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", ())
super().__init__(
encoder_config={"target": "torch.nn.Identity"},
decoder_config={"target": "torch.nn.Identity"},
regularizer_config={"target": "torch.nn.Identity"},
loss_config=kwargs.pop("lossconfig"),
**kwargs,
)
assert ddconfig["double_z"]
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def encode(self, x):
assert (
not self.training
), f"{self.__class__.__name__} only supports inference currently"
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **decoder_kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z, **decoder_kwargs)
return dec
class AutoencoderKLInferenceWrapper(AutoencoderKL):
def encode(self, x):
return super().encode(x).sample()
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
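Every sub-module above (encoder, decoder, loss, regularizer, optimizer) is described by a small {"target": "<import path>", "params": {...}} block and resolved through instantiate_from_config; AutoencoderKL reuses the same mechanism, handing torch.nn.Identity targets to the parent class before installing the real Encoder/Decoder. A minimal sketch of the pattern itself, using a stock torch class as the target and assuming the sgm package is installed (e.g. via the setup.py above):

import torch
from sgm.util import instantiate_from_config

conv_cfg = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 1},
}
conv = instantiate_from_config(conv_cfg)
print(conv)   # Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))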

324
sgm/models/diffusion.py Normal file
View File

@@ -0,0 +1,324 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union
import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR
from ..modules import UNCONDITIONAL_CONFIG
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
class DiffusionEngine(pl.LightningModule):
def __init__(
self,
network_config,
denoiser_config,
first_stage_config,
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
network_wrapper: Union[None, str] = None,
ckpt_path: Union[None, str] = None,
use_ema: bool = False,
ema_decay_rate: float = 0.9999,
scale_factor: float = 1.0,
disable_first_stage_autocast=False,
input_key: str = "jpg",
log_keys: Union[List, None] = None,
no_cond_log: bool = False,
compile_model: bool = False,
):
super().__init__()
self.log_keys = log_keys
self.input_key = input_key
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.AdamW"}
)
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = (
instantiate_from_config(sampler_config)
if sampler_config is not None
else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self.scheduler_config = scheduler_config
self._init_first_stage(first_stage_config)
self.loss_fn = (
instantiate_from_config(loss_fn_config)
if loss_fn_config is not None
else None
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
def init_from_ckpt(
self,
path: str,
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def get_input(self, batch):
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in bchw format
return batch[self.input_key]
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
out = self.first_stage_model.decode(z)
return out
@torch.no_grad()
def encode_first_stage(self, x):
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
z = self.first_stage_model.encode(x)
z = self.scale_factor * z
return z
def forward(self, x, batch):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
loss_mean = loss.mean()
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
x = self.encode_first_stage(x)
batch["global_step"] = self.global_step
loss, loss_dict = self(x, batch)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
self.log(
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
if self.scheduler_config is not None:
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
return loss
def on_train_start(self, *args, **kwargs):
if self.sampler is None or self.loss_fn is None:
raise ValueError("Sampler and loss function need to be set for training.")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
for embedder in self.conditioner.embedders:
if embedder.is_trainable:
params = params + list(embedder.parameters())
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [opt], scheduler
return opt
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(self.device)
denoiser = lambda input, sigma, c: self.denoiser(
self.model, input, sigma, c, **kwargs
)
samples = self.sampler(denoiser, randn, cond, uc=uc)
return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
"""
Defines heuristics to log different conditionings.
These can be lists of strings (text-to-image), tensors, ints, ...
"""
image_h, image_w = batch[self.input_key].shape[2:]
log = dict()
for embedder in self.conditioner.embedders:
if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
# class-conditional, convert integer to string
x = [str(x[i].item()) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
elif x.dim() == 2:
# size and crop cond and the like
x = [
"x".join([str(xx) for xx in x[i].tolist()])
for i in range(x.shape[0])
]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
                elif isinstance(x, (list, ListConfig)):
if isinstance(x[0], str):
# strings
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    elif isinstance(x[0], (list, ListConfig)):
# # case: videos processed
x = [xx[0] for xx in x]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
log[embedder.input_key] = xc
return log
@torch.no_grad()
def log_images(
self,
batch: Dict,
N: int = 8,
sample: bool = True,
ucg_keys: List[str] = None,
**kwargs,
) -> Dict:
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
)
else:
ucg_keys = conditioner_input_keys
log = dict()
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys
if len(self.conditioner.embedders) > 0
else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
log["inputs"] = x
z = self.encode_first_stage(x)
log["reconstructions"] = self.decode_first_stage(z)
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
if sample:
with self.ema_scope("Plotting"):
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
log["samples"] = samples
return log
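log_images doubles as a template for manual inference: build conditionings with the conditioner, sample in latent space under the EMA weights, then decode with the first stage. A hedged sketch along those lines, assuming `engine` is a fully constructed DiffusionEngine with a sampler and conditioner and `batch` is a dict in the expected format; the latent shape is an assumption that depends on the first-stage model and target resolution:

import torch

@torch.no_grad()
def generate(engine, batch, n=4, z_shape=(4, 64, 64)):
    # Mirrors log_images above; assumes `batch` already holds n examples on the right device.
    ucg_keys = [e.input_key for e in engine.conditioner.embedders]
    c, uc = engine.conditioner.get_unconditional_conditioning(
        batch, force_uc_zero_embeddings=ucg_keys if len(ucg_keys) > 0 else []
    )
    with engine.ema_scope("sampling"):
        samples = engine.sample(c, uc=uc, batch_size=n, shape=z_shape)
        images = engine.decode_first_stage(samples)  # back to image space, roughly in [-1, 1]
    return images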

6
sgm/modules/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from .encoders.modules import GeneralConditioner
UNCONDITIONAL_CONFIG = {
"target": "sgm.modules.GeneralConditioner",
"params": {"emb_models": []},
}

947
sgm/modules/attention.py Normal file
View File

@@ -0,0 +1,947 @@
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
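        # GEGLU gating (Shazeer 2020, "GLU Variants Improve Transformer"): the projection
        # produces 2 * dim_out features; one half gates the other through GELU.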
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
v = repeat(
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads with a dimension of {dim_head}."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
"b ... -> (b n) ...",
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode="softmax",
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
)
if not XFORMERS_IS_AVAILABLE:
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
kwargs = {"x": x}
if context is not None:
kwargs.update({"context": context})
if additional_tokens is not None:
kwargs.update({"additional_tokens": additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update(
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
)
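        # NOTE: the kwargs dict assembled above is only consumed by the commented-out
        # mixed_checkpoint path; the checkpoint call below forwards just (x, context), so
        # additional_tokens / n_times_crossframe_attn_in_self are not passed on to _forward
        # and fall back to their defaults.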
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
x = (
self.attn2(
self.norm2(x), context=context, additional_tokens=additional_tokens
)
+ x
)
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode="softmax",
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type="softmax",
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
                assert all(
                    map(lambda x: x == context_dim[0], context_dim)
                ), "need homogeneous context_dim to match depth automatically"
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
)
for d in range(depth)
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
def benchmark_attn():
# Lets define a helpful benchmarking function:
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
key = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
value = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
    print("q/k/v shape:", query.shape, key.shape, value.shape)
# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel
# Helpful arguments mapper
backend_map = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
}
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
print(
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with sdp_kernel(**backend_map[SDPBackend.MATH]):
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
            with record_function("Math implementation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
def run_model(model, x, context):
return model(x, context)
def benchmark_transformer_blocks():
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
checkpoint = True
compile = False
batch_size = 32
h, w = 64, 64
context_len = 77
embed_dimension = 1024
context_dim = 1024
d_head = 64
transformer_depth = 4
n_heads = embed_dimension // d_head
dtype = torch.float16
model_native = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
use_checkpoint=checkpoint,
attn_type="softmax",
depth=transformer_depth,
sdp_backend=SDPBackend.FLASH_ATTENTION,
).to(device)
model_efficient_attn = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
depth=transformer_depth,
use_checkpoint=checkpoint,
attn_type="softmax-xformers",
).to(device)
if not checkpoint and compile:
print("compiling models")
model_native = torch.compile(model_native)
model_efficient_attn = torch.compile(model_efficient_attn)
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with torch.autocast("cuda"):
print(
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
)
print(
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
)
print(75 * "+")
print("NATIVE")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("NativeAttention stats"):
for _ in range(25):
model_native(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
print(75 * "+")
print("Xformers")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("xformers stats"):
for _ in range(25):
model_efficient_attn(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
def test01():
# conv1x1 vs linear
from ..util import count_params
conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
print(count_params(conv))
linear = torch.nn.Linear(3, 32).cuda()
print(count_params(linear))
print(conv.weight.shape)
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
linear.bias = torch.nn.Parameter(conv.bias)
print(linear.weight.shape)
x = torch.randn(11, 3, 64, 64).cuda()
xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
print(xr.shape)
out_linear = linear(xr)
print(out_linear.mean(), out_linear.shape)
out_conv = conv(x)
print(out_conv.mean(), out_conv.shape)
print("done with test01.\n")
def test02():
# try cosine flash attention
import time
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print("testing cosine flash attention...")
DIM = 1024
SEQLEN = 4096
BS = 16
print(" softmax (vanilla) first...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="softmax",
).cuda()
try:
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
except RuntimeError as e:
# likely oom
print(str(e))
print("\n now flash-cosine...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="flash-cosine",
).cuda()
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
print("done with test02.\n")
if __name__ == "__main__":
# test01()
# test02()
# test03()
# benchmark_attn()
benchmark_transformer_blocks()
print("done.")
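SpatialTransformer is the piece the UNet inserts wherever cross-attention to an external context (e.g. text-encoder tokens) is needed: it flattens the b c h w feature map into tokens, runs the stack of BasicTransformerBlocks, and reshapes back, with a residual connection around the whole block. A small smoke test on random tensors; the dimensions are illustrative, and the default "softmax" attention mode needs PyTorch >= 2.0 or xformers as a fallback:

import torch

block = SpatialTransformer(
    in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=32, use_linear=True
)
x = torch.randn(2, 64, 8, 8)     # b c h w feature map
ctx = torch.randn(2, 77, 32)     # b t d context tokens, e.g. from a text encoder
out = block(x, context=ctx)
print(out.shape)                 # torch.Size([2, 64, 8, 8]), same shape as the input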

View File

View File

@@ -0,0 +1,246 @@
from typing import Any, Union
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from ....util import default, instantiate_from_config
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
class LatentLPIPS(nn.Module):
def __init__(
self,
decoder_config,
perceptual_weight=1.0,
latent_weight=1.0,
scale_input_to_tgt_size=False,
scale_tgt_to_input_size=False,
perceptual_weight_on_inputs=0.0,
):
super().__init__()
self.scale_input_to_tgt_size = scale_input_to_tgt_size
self.scale_tgt_to_input_size = scale_tgt_to_input_size
self.init_decoder(decoder_config)
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
self.latent_weight = latent_weight
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
def init_decoder(self, config):
self.decoder = instantiate_from_config(config)
if hasattr(self.decoder, "encoder"):
del self.decoder.encoder
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
log = dict()
loss = (latent_inputs - latent_predictions) ** 2
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
image_reconstructions = None
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean()
+ self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
image_reconstructions.shape[2:],
mode="bicubic",
antialias=True,
)
elif self.scale_tgt_to_input_size:
image_reconstructions = torch.nn.functional.interpolate(
image_reconstructions,
image_inputs.shape[2:],
mode="bicubic",
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log
class GeneralLPIPSWithDiscriminator(nn.Module):
def __init__(
self,
disc_start: int,
logvar_init: float = 0.0,
pixelloss_weight=1.0,
disc_num_layers: int = 3,
disc_in_channels: int = 3,
disc_factor: float = 1.0,
disc_weight: float = 1.0,
perceptual_weight: float = 1.0,
disc_loss: str = "hinge",
scale_input_to_tgt_size: bool = False,
dims: int = 2,
learn_logvar: bool = False,
regularization_weights: Union[None, dict] = None,
):
super().__init__()
self.dims = dims
if self.dims > 2:
print(
f"running with dims={dims}. This means that for perceptual loss calculation, "
f"the LPIPS loss will be applied to each frame independently. "
)
self.scale_input_to_tgt_size = scale_input_to_tgt_size
assert disc_loss in ["hinge", "vanilla"]
self.pixel_weight = pixelloss_weight
self.perceptual_loss = LPIPS().eval()
self.perceptual_weight = perceptual_weight
# output log variance
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
self.learn_logvar = learn_logvar
self.discriminator = NLayerDiscriminator(
input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
).apply(weights_init)
self.discriminator_iter_start = disc_start
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
self.disc_factor = disc_factor
self.discriminator_weight = disc_weight
self.regularization_weights = default(regularization_weights, {})
def get_trainable_parameters(self) -> Any:
return self.discriminator.parameters()
def get_trainable_autoencoder_parameters(self) -> Any:
if self.learn_logvar:
yield self.logvar
yield from ()
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
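        # Adaptive generator weighting as in VQGAN (Esser et al., 2021): scale the adversarial
        # term by the ratio of the gradient norms of the reconstruction/NLL loss and the
        # generator loss at the decoder's last layer, so that neither term dominates.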
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
nll_grads = torch.autograd.grad(
nll_loss, self.last_layer[0], retain_graph=True
)[0]
g_grads = torch.autograd.grad(
g_loss, self.last_layer[0], retain_graph=True
)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
d_weight = d_weight * self.discriminator_weight
return d_weight
def forward(
self,
regularization_log,
inputs,
reconstructions,
optimizer_idx,
global_step,
last_layer=None,
split="train",
weights=None,
):
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
(inputs, reconstructions),
)
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(
inputs.contiguous(), reconstructions.contiguous()
)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
# now the GAN part
if optimizer_idx == 0:
# generator update
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.disc_factor > 0.0:
try:
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
except RuntimeError:
assert not self.training
d_weight = torch.tensor(0.0)
else:
d_weight = torch.tensor(0.0)
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
loss = weighted_nll_loss + d_weight * disc_factor * g_loss
log = dict()
for k in regularization_log:
if k in self.regularization_weights:
loss = loss + self.regularization_weights[k] * regularization_log[k]
log[f"{split}/{k}"] = regularization_log[k].detach().mean()
log.update(
{
"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/logvar".format(split): self.logvar.detach(),
"{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
"{}/g_loss".format(split): g_loss.detach().mean(),
}
)
return loss, log
if optimizer_idx == 1:
# second pass for discriminator update
logits_real = self.discriminator(inputs.contiguous().detach())
logits_fake = self.discriminator(reconstructions.contiguous().detach())
disc_factor = adopt_weight(
self.disc_factor, global_step, threshold=self.discriminator_iter_start
)
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
log = {
"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
"{}/logits_real".format(split): logits_real.detach().mean(),
"{}/logits_fake".format(split): logits_fake.detach().mean(),
}
return d_loss, log
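
A minimal standalone sketch (not part of this commit) of how calculate_adaptive_weight balances the reconstruction and adversarial terms; the toy last_layer, inputs, and losses below are made up purely for illustration.

import torch
import torch.nn as nn

last_layer = nn.Parameter(torch.randn(4, 4))   # stand-in for the decoder's last layer
x = torch.randn(8, 4)
recon = x @ last_layer                         # stand-in decoder output
nll_loss = (recon - x).abs().mean()            # reconstruction term
g_loss = -recon.mean()                         # stand-in generator (GAN) term

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
print(d_weight)  # scales the GAN term so it does not overwhelm the NLL term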

View File

@@ -0,0 +1,53 @@
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import DiagonalGaussianDistribution
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
def measure_perplexity(predicted_indices, num_centroids):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
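
For intuition, a small sketch of measure_perplexity with hypothetical indices: when four centroids are used uniformly, the perplexity equals the number of centroids.

import torch
import torch.nn.functional as F

idx = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])   # hypothetical code assignments
enc = F.one_hot(idx, 4).float()
avg = enc.mean(0)
perplexity = (-(avg * torch.log(avg + 1e-10)).sum()).exp()
print(perplexity)  # ~4.0, i.e. all centroids used equally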

View File

@@ -0,0 +1,7 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Model, Encoder, Decoder
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper

View File

@@ -0,0 +1,63 @@
import torch.nn as nn
from ...util import append_dims, instantiate_from_config
class Denoiser(nn.Module):
def __init__(self, weighting_config, scaling_config):
super().__init__()
self.weighting = instantiate_from_config(weighting_config)
self.scaling = instantiate_from_config(scaling_config)
def possibly_quantize_sigma(self, sigma):
return sigma
def possibly_quantize_c_noise(self, c_noise):
return c_noise
def w(self, sigma):
return self.weighting(sigma)
def __call__(self, network, input, sigma, cond):
sigma = self.possibly_quantize_sigma(sigma)
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
return network(input * c_in, c_noise, cond) * c_out + input * c_skip
class DiscreteDenoiser(Denoiser):
def __init__(
self,
weighting_config,
scaling_config,
num_idx,
discretization_config,
do_append_zero=False,
quantize_c_noise=True,
flip=True,
):
super().__init__(weighting_config, scaling_config)
sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
self.register_buffer("sigmas", sigmas)
self.quantize_c_noise = quantize_c_noise
def sigma_to_idx(self, sigma):
dists = sigma - self.sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def idx_to_sigma(self, idx):
return self.sigmas[idx]
def possibly_quantize_sigma(self, sigma):
return self.idx_to_sigma(self.sigma_to_idx(sigma))
def possibly_quantize_c_noise(self, c_noise):
if self.quantize_c_noise:
return self.sigma_to_idx(c_noise)
else:
return c_noise
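
A hedged sketch of the preconditioning Denoiser.__call__ applies, D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise). The eps_scaling helper mirrors EpsScaling below and dummy_network is a made-up stand-in for the UNet.

import torch

def eps_scaling(sigma):
    c_skip = torch.ones_like(sigma)
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.0) ** 0.5
    c_noise = sigma.clone()
    return c_skip, c_out, c_in, c_noise

def dummy_network(x, c_noise, cond=None):
    return torch.zeros_like(x)  # pretend the model predicts zero noise

x = torch.randn(2, 4)
sigma = torch.full((2, 1), 0.5)
c_skip, c_out, c_in, c_noise = eps_scaling(sigma)
denoised = dummy_network(x * c_in, c_noise) * c_out + x * c_skip
print(denoised.shape)  # (2, 4); equals x here because the dummy predicts zero noise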

View File

@@ -0,0 +1,31 @@
import torch
class EDMScaling:
def __init__(self, sigma_data=0.5):
self.sigma_data = sigma_data
def __call__(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise
class EpsScaling:
def __call__(self, sigma):
c_skip = torch.ones_like(sigma, device=sigma.device)
c_out = -sigma
c_in = 1 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScaling:
def __call__(self, sigma):
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise

View File

@@ -0,0 +1,24 @@
import torch
class UnitWeighting:
def __call__(self, sigma):
return torch.ones_like(sigma, device=sigma.device)
class EDMWeighting:
def __init__(self, sigma_data=0.5):
self.sigma_data = sigma_data
def __call__(self, sigma):
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
class VWeighting(EDMWeighting):
def __init__(self):
super().__init__(sigma_data=1.0)
class EpsWeighting:
def __call__(self, sigma):
return sigma**-2.0

View File

@@ -0,0 +1,65 @@
import torch
import numpy as np
from functools import partial
from ...util import append_zero
from ...modules.diffusionmodules.util import make_beta_schedule
class Discretization:
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
sigmas = self.get_sigmas(n, device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))
class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
def get_sigmas(self, n, device):
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
return sigmas
class LegacyDDPMDiscretization(Discretization):
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
legacy_range=True,
):
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
)
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
self.legacy_range = legacy_range
def get_sigmas(self, n, device):
if n < self.num_timesteps:
c = self.num_timesteps // n
if self.legacy_range:
timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
timesteps += 1 # Legacy LDM Hack
else:
timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
timesteps -= 1
timesteps = timesteps[1:]
alphas_cumprod = self.alphas_cumprod[timesteps]
else:
alphas_cumprod = self.alphas_cumprod
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
return torch.flip(sigmas, (0,))
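
A standalone sketch of the sigma schedule EDMDiscretization.get_sigmas builds (the rho-spaced Karras/EDM schedule); the values are illustrative and the device argument is dropped for brevity.

import torch

sigma_min, sigma_max, rho, n = 0.02, 80.0, 7.0, 10
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # monotonically decreasing from sigma_max to sigma_min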

View File

@@ -0,0 +1,53 @@
from functools import partial
import torch
from ...util import default, instantiate_from_config
class VanillaCFG:
"""
implements parallelized CFG
"""
def __init__(self, scale, dyn_thresh_config=None):
scale_schedule = lambda scale, sigma: scale # independent of step
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
},
)
)
def __call__(self, x, sigma):
x_u, x_c = x.chunk(2)
scale_value = self.scale_schedule(sigma)
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in ["vector", "crossattn", "concat"]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class IdentityGuider:
def __call__(self, x, sigma):
return x
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
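
For reference, a sketch of the combination VanillaCFG performs through the default NoDynamicThresholding guider: the doubled batch is split into unconditional and conditional halves and blended with the guidance scale.

import torch

scale = 7.5
x = torch.randn(4, 3)                 # stacked [uncond, cond] model outputs
x_u, x_c = x.chunk(2)
x_pred = x_u + scale * (x_c - x_u)
print(x_pred.shape)                   # (2, 3): one guided prediction per sample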

View File

@@ -0,0 +1,69 @@
from typing import List, Optional, Union
import torch
import torch.nn as nn
from omegaconf import ListConfig
from taming.modules.losses.lpips import LPIPS
from ...util import append_dims, instantiate_from_config
class StandardDiffusionLoss(nn.Module):
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
super().__init__()
assert type in ["l2", "l1", "lpips"]
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
self.type = type
self.offset_noise_level = offset_noise_level
if type == "lpips":
self.lpips = LPIPS().eval()
if not batch2model_keys:
batch2model_keys = []
if isinstance(batch2model_keys, str):
batch2model_keys = [batch2model_keys]
self.batch2model_keys = set(batch2model_keys)
def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch)
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
noise = torch.randn_like(input)
if self.offset_noise_level > 0.0:
noise = noise + self.offset_noise_level * append_dims(
torch.randn(input.shape[0], device=input.device), input.ndim
)
noised_input = input + noise * append_dims(sigmas, input.ndim)
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
w = append_dims(denoiser.w(sigmas), input.ndim)
return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):
if self.type == "l2":
return torch.mean(
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
elif self.type == "l1":
return torch.mean(
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
elif self.type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
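
A small sketch of the weighted l2 branch of get_loss: the squared error is scaled by w (e.g. denoiser.w(sigmas) broadcast to the input shape) and averaged over all non-batch dimensions, yielding one loss value per sample.

import torch

model_output = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 3, 8, 8)
w = torch.ones(2, 1, 1, 1)            # stand-in for append_dims(denoiser.w(sigmas), 4)
loss = torch.mean((w * (model_output - target) ** 2).reshape(2, -1), 1)
print(loss.shape)                     # (2,), one loss value per sample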

View File

@@ -0,0 +1,743 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Callable, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except ImportError:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def attention(self, h_: torch.Tensor) -> torch.Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(
lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
)
h_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v
        )  # the default scale is dim ** -0.5
# compute attention
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x, **kwargs):
h_ = x
h_ = self.attention(h_)
h_ = self.proj_out(h_)
return x + h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.attention_op: Optional[Any] = None
def attention(self, h_: torch.Tensor) -> torch.Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
def forward(self, x, **kwargs):
h_ = x
h_ = self.attention(h_)
h_ = self.proj_out(h_)
return x + h_
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        x_in = x
        x = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(x, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        # residual connection on the original (b, c, h, w) tensor; adding the
        # flattened x would cause a shape mismatch
        return x_in + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
if (
version.parse(torch.__version__) < version.parse("2.0.0")
and attn_type != "none"
):
assert XFORMERS_IS_AVAILABLE, (
f"We do not support vanilla attention in {torch.__version__} anymore, "
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Model(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
make_attn_cls = self._make_attn()
make_resblock_cls = self._make_resblock()
make_conv_cls = self._make_conv()
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
self.mid.block_2 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
make_resblock_cls(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn_cls(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = make_conv_cls(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def _make_attn(self) -> Callable:
return make_attn
def _make_resblock(self) -> Callable:
return ResnetBlock
def _make_conv(self) -> Callable:
return torch.nn.Conv2d
def get_last_layer(self, **kwargs):
return self.conv_out.weight
def forward(self, z, **kwargs):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h
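
A small sketch of the single-head reshape used in AttnBlock.attention, assuming PyTorch >= 2.0 for scaled_dot_product_attention; the shapes are illustrative.

import torch
from einops import rearrange

b, c, h, w = 1, 8, 4, 4
q = k = v = torch.randn(b, c, h, w)
q, k, v = map(lambda t: rearrange(t, "b c h w -> b 1 (h w) c"), (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = rearrange(out, "b 1 (h w) c -> b c h w", h=h, w=w)
print(out.shape)  # (1, 8, 4, 4): same spatial layout as the input feature map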

File diff suppressed because it is too large

View File

@@ -0,0 +1,365 @@
"""
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
from typing import Dict, Union
import torch
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm
from ...modules.diffusionmodules.sampling_utils import (
get_ancestral_step,
linear_multistep_coeff,
to_d,
to_neg_log_sigma,
to_sigma,
)
from ...util import append_dims, default, instantiate_from_config
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
class BaseDiffusionSampler:
def __init__(
self,
discretization_config: Union[Dict, ListConfig, OmegaConf],
num_steps: Union[int, None] = None,
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
verbose: bool = False,
device: str = "cuda",
):
self.num_steps = num_steps
self.discretization = instantiate_from_config(discretization_config)
self.guider = instantiate_from_config(
default(
guider_config,
DEFAULT_GUIDER,
)
)
self.verbose = verbose
self.device = device
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
sigmas = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device
)
uc = default(uc, cond)
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas)
s_in = x.new_ones([x.shape[0]])
return x, s_in, sigmas, num_sigmas, cond, uc
def denoise(self, x, denoiser, sigma, cond, uc):
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
denoised = self.guider(denoised, sigma)
return denoised
def get_sigma_gen(self, num_sigmas):
sigma_generator = range(num_sigmas - 1)
if self.verbose:
print("#" * 30, " Sampling setting ", "#" * 30)
print(f"Sampler: {self.__class__.__name__}")
print(f"Discretization: {self.discretization.__class__.__name__}")
print(f"Guider: {self.guider.__class__.__name__}")
sigma_generator = tqdm(
sigma_generator,
total=num_sigmas,
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
)
return sigma_generator
class SingleStepDiffusionSampler(BaseDiffusionSampler):
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
raise NotImplementedError
def euler_step(self, x, d, dt):
return x + dt * d
class EDMSampler(SingleStepDiffusionSampler):
def __init__(
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.s_churn = s_churn
self.s_tmin = s_tmin
self.s_tmax = s_tmax
self.s_noise = s_noise
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
sigma_hat = sigma * (gamma + 1.0)
if gamma > 0:
eps = torch.randn_like(x) * self.s_noise
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
d = to_d(x, sigma_hat, denoised)
dt = append_dims(next_sigma - sigma_hat, x.ndim)
euler_step = self.euler_step(x, d, dt)
x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
)
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
)
return x
class AncestralSampler(SingleStepDiffusionSampler):
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.eta = eta
self.s_noise = s_noise
self.noise_sampler = lambda x: torch.randn_like(x)
def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
d = to_d(x, sigma, denoised)
dt = append_dims(sigma_down - sigma, x.ndim)
return self.euler_step(x, d, dt)
def ancestral_step(self, x, sigma, next_sigma, sigma_up):
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0,
x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
x,
)
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
)
return x
class LinearMultistepSampler(BaseDiffusionSampler):
def __init__(
self,
order=4,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.order = order
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
ds = []
sigmas_cpu = sigmas.detach().cpu().numpy()
for i in self.get_sigma_gen(num_sigmas):
sigma = s_in * sigmas[i]
denoised = denoiser(
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
)
denoised = self.guider(denoised, sigma)
d = to_d(x, sigma, denoised)
ds.append(d)
if len(ds) > self.order:
ds.pop(0)
cur_order = min(i + 1, self.order)
coeffs = [
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
for j in range(cur_order)
]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
class EulerEDMSampler(EDMSampler):
def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
return euler_step
class HeunEDMSampler(EDMSampler):
def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
if torch.sum(next_sigma) < 1e-14:
# Save a network evaluation if all noise levels are 0
return euler_step
else:
denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
d_new = to_d(euler_step, next_sigma, denoised)
d_prime = (d + d_new) / 2.0
# apply correction if noise level is not 0
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
)
return x
class EulerAncestralSampler(AncestralSampler):
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
denoised = self.denoise(x, denoiser, sigma, cond, uc)
x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
return x
class DPMPP2SAncestralSampler(AncestralSampler):
def get_variables(self, sigma, sigma_down):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
h = t_next - t
s = t + 0.5 * h
return h, s, t, t_next
def get_mult(self, h, s, t, t_next):
mult1 = to_sigma(s) / to_sigma(t)
mult2 = (-0.5 * h).expm1()
mult3 = to_sigma(t_next) / to_sigma(t)
mult4 = (-h).expm1()
return mult1, mult2, mult3, mult4
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
denoised = self.denoise(x, denoiser, sigma, cond, uc)
x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
if torch.sum(sigma_down) < 1e-14:
# Save a network evaluation if all noise levels are 0
x = x_euler
else:
h, s, t, t_next = self.get_variables(sigma, sigma_down)
mult = [
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
]
x2 = mult[0] * x - mult[1] * denoised
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
x_dpmpp2s = mult[2] * x - mult[3] * denoised2
# apply correction if noise level is not 0
x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
return x
class DPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
h = t_next - t
if previous_sigma is not None:
h_last = t - to_neg_log_sigma(previous_sigma)
r = h_last / h
return h, r, t, t_next
else:
return h, None, t, t_next
def get_mult(self, h, r, t, t_next, previous_sigma):
mult1 = to_sigma(t_next) / to_sigma(t)
mult2 = (-h).expm1()
if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
return mult1, mult2, mult3, mult4
else:
return mult1, mult2
def sampler_step(
self,
old_denoised,
previous_sigma,
sigma,
next_sigma,
denoiser,
x,
cond,
uc=None,
):
denoised = self.denoise(x, denoiser, sigma, cond, uc)
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
]
x_standard = mult[0] * x - mult[1] * denoised
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
# Save a network evaluation if all noise levels are 0 or on the first step
return x_standard, denoised
else:
denoised_d = mult[2] * denoised - mult[3] * old_denoised
x_advanced = mult[0] * x - mult[1] * denoised_d
# apply correction if noise level is not 0 and not first step
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
old_denoised = None
for i in self.get_sigma_gen(num_sigmas):
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * sigmas[i - 1],
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc=uc,
)
return x
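
A toy sketch of one Euler step as performed in EDMSampler.sampler_step with gamma == 0: x moves along d = (x - denoised) / sigma by dt = next_sigma - sigma; the zero tensor is a stand-in for the denoiser output.

import torch

x = torch.randn(1, 4)
sigma, next_sigma = 10.0, 8.0
denoised = torch.zeros_like(x)        # stand-in for denoiser(x, sigma, cond)
d = (x - denoised) / sigma            # to_d
x = x + (next_sigma - sigma) * d      # euler_step
print(x)                              # x pulled slightly toward the denoised estimate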

View File

@@ -0,0 +1,48 @@
import torch
from scipy import integrate
from ...util import append_dims
class NoDynamicThresholding:
def __call__(self, uncond, cond, scale):
return uncond + scale * (cond - uncond)
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
if order - 1 > i:
raise ValueError(f"Order {order} too high for step {i}")
def fn(tau):
prod = 1.0
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
if not eta:
return sigma_to, 0.0
sigma_up = torch.minimum(
sigma_to,
eta
* (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
def to_d(x, sigma, denoised):
return (x - denoised) / append_dims(sigma, x.ndim)
def to_neg_log_sigma(sigma):
return sigma.log().neg()
def to_sigma(neg_log_sigma):
return neg_log_sigma.neg().exp()

View File

@@ -0,0 +1,31 @@
import torch
from ...util import default, instantiate_from_config
class EDMSampling:
def __init__(self, p_mean=-1.2, p_std=1.2):
self.p_mean = p_mean
self.p_std = p_std
def __call__(self, n_samples, rand=None):
log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
return log_sigma.exp()
class DiscreteSampling:
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
self.num_idx = num_idx
self.sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
def idx_to_sigma(self, idx):
return self.sigmas[idx]
def __call__(self, n_samples, rand=None):
idx = default(
rand,
torch.randint(0, self.num_idx, (n_samples,)),
)
return self.idx_to_sigma(idx)

View File

@@ -0,0 +1,308 @@
"""
adapted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
and
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
thanks!
"""
import math
import torch
import torch.nn as nn
from einops import repeat
def make_beta_schedule(
schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
return betas.numpy()
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def mixed_checkpoint(func, inputs: dict, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
it also works with non-tensor inputs
:param func: the function to evaluate.
:param inputs: the argument dictionary to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
tensor_inputs = [
inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
]
non_tensor_keys = [
key for key in inputs if not isinstance(inputs[key], torch.Tensor)
]
non_tensor_inputs = [
inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
]
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
return MixedCheckpointFunction.apply(
func,
len(tensor_inputs),
len(non_tensor_inputs),
tensor_keys,
non_tensor_keys,
*args,
)
else:
return func(**inputs)
class MixedCheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
run_function,
length_tensors,
length_non_tensors,
tensor_keys,
non_tensor_keys,
*args,
):
ctx.end_tensors = length_tensors
ctx.end_non_tensors = length_tensors + length_non_tensors
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
assert (
len(tensor_keys) == length_tensors
and len(non_tensor_keys) == length_non_tensors
)
ctx.input_tensors = {
key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
}
ctx.input_non_tensors = {
key: val
for (key, val) in zip(
non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
)
}
ctx.run_function = run_function
ctx.input_params = list(args[ctx.end_non_tensors :])
with torch.no_grad():
output_tensors = ctx.run_function(
**ctx.input_tensors, **ctx.input_non_tensors
)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
ctx.input_tensors = {
key: ctx.input_tensors[key].detach().requires_grad_(True)
for key in ctx.input_tensors
}
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = {
key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
for key in ctx.input_tensors
}
# shallow_copies.update(additional_args)
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
input_grads = torch.autograd.grad(
output_tensors,
list(ctx.input_tensors.values()) + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (
(None, None, None, None, None)
+ input_grads[: ctx.end_tensors]
+ (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ input_grads[ctx.end_tensors :]
)
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from packaging import version
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
class IdentityWrapper(nn.Module):
def __init__(self, diffusion_model, compile_model: bool = False):
super().__init__()
compile = (
torch.compile
if (version.parse(torch.__version__) >= version.parse("2.0.0"))
and compile_model
else lambda x: x
)
self.diffusion_model = compile(diffusion_model)
def forward(self, *args, **kwargs):
return self.diffusion_model(*args, **kwargs)
class OpenAIWrapper(IdentityWrapper):
def forward(
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
) -> torch.Tensor:
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
return self.diffusion_model(
x,
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
**kwargs
)

View File

View File

@@ -0,0 +1,102 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
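
A short sketch of the KL term DiagonalGaussianDistribution.kl() computes against a standard normal, 0.5 * sum(mu^2 + var - 1 - logvar) over the non-batch dimensions; a standard-normal posterior gives zero KL.

import torch

mean = torch.zeros(2, 4, 8, 8)
logvar = torch.zeros(2, 4, 8, 8)
var = logvar.exp()
kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])
print(kl)  # zeros: a standard normal has zero KL divergence to itself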

86
sgm/modules/ema.py Normal file
View File

@@ -0,0 +1,86 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
                    assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
                assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
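
A hedged sketch of the EMA update LitEma.forward applies: the decay is warmed up via num_updates, and the shadow weights move toward the live weights by (1 - decay); the numbers are illustrative.

import torch

decay, num_updates = 0.9999, 100
decay = min(decay, (1 + num_updates) / (10 + num_updates))   # warm-up: ~0.918 here
shadow = torch.zeros(3)                                      # EMA (shadow) weights
param = torch.ones(3)                                        # current model weights
shadow = shadow - (1.0 - decay) * (shadow - param)           # in-place sub_ in LitEma
print(decay, shadow)  # early in training decay is small, so the shadow tracks quickly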

View File

View File

@@ -0,0 +1,960 @@
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import kornia
import numpy as np
import open_clip
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import (
ByT5Tokenizer,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
)
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (
autocast,
count_params,
default,
disabled_train,
expand_dims_like,
instantiate_from_config,
)
class AbstractEmbModel(nn.Module):
def __init__(self):
super().__init__()
self._is_trainable = None
self._ucg_rate = None
self._input_key = None
@property
def is_trainable(self) -> bool:
return self._is_trainable
@property
def ucg_rate(self) -> Union[float, torch.Tensor]:
return self._ucg_rate
@property
def input_key(self) -> str:
return self._input_key
@is_trainable.setter
def is_trainable(self, value: bool):
self._is_trainable = value
@ucg_rate.setter
def ucg_rate(self, value: Union[float, torch.Tensor]):
self._ucg_rate = value
@input_key.setter
def input_key(self, value: str):
self._input_key = value
@is_trainable.deleter
def is_trainable(self):
del self._is_trainable
@ucg_rate.deleter
def ucg_rate(self):
del self._ucg_rate
@input_key.deleter
def input_key(self):
del self._input_key
class GeneralConditioner(nn.Module):
OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
def __init__(self, emb_models: Union[List, ListConfig]):
super().__init__()
embedders = []
for n, embconfig in enumerate(emb_models):
embedder = instantiate_from_config(embconfig)
assert isinstance(
embedder, AbstractEmbModel
), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
embedder.is_trainable = embconfig.get("is_trainable", False)
embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
if not embedder.is_trainable:
embedder.train = disabled_train
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
print(
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
)
if "input_key" in embconfig:
embedder.input_key = embconfig["input_key"]
elif "input_keys" in embconfig:
embedder.input_keys = embconfig["input_keys"]
else:
raise KeyError(
f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
)
embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
if embedder.legacy_ucg_val is not None:
embedder.ucg_prng = np.random.RandomState()
embedders.append(embedder)
self.embedders = nn.ModuleList(embedders)
def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
assert embedder.legacy_ucg_val is not None
p = embedder.ucg_rate
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if embedder.ucg_prng.choice(2, p=[1 - p, p]):
batch[embedder.input_key][i] = val
return batch
def forward(
self, batch: Dict, force_zero_embeddings: Optional[List] = None
) -> Dict:
output = dict()
if force_zero_embeddings is None:
force_zero_embeddings = []
for embedder in self.embedders:
embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
with embedding_context():
if hasattr(embedder, "input_key") and (embedder.input_key is not None):
if embedder.legacy_ucg_val is not None:
batch = self.possibly_get_ucg_val(embedder, batch)
emb_out = embedder(batch[embedder.input_key])
elif hasattr(embedder, "input_keys"):
emb_out = embedder(*[batch[k] for k in embedder.input_keys])
assert isinstance(
emb_out, (torch.Tensor, list, tuple)
), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
if not isinstance(emb_out, (list, tuple)):
emb_out = [emb_out]
for emb in emb_out:
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
emb = (
expand_dims_like(
torch.bernoulli(
(1.0 - embedder.ucg_rate)
* torch.ones(emb.shape[0], device=emb.device)
),
emb,
)
* emb
)
if (
hasattr(embedder, "input_key")
and embedder.input_key in force_zero_embeddings
):
emb = torch.zeros_like(emb)
if out_key in output:
output[out_key] = torch.cat(
(output[out_key], emb), self.KEY2CATDIM[out_key]
)
else:
output[out_key] = emb
return output
def get_unconditional_conditioning(
self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
ucg_rates = list()
for embedder in self.embedders:
ucg_rates.append(embedder.ucg_rate)
embedder.ucg_rate = 0.0
c = self(batch_c)
uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
for embedder, rate in zip(self.embedders, ucg_rates):
embedder.ucg_rate = rate
return c, uc
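# A minimal usage sketch (illustrative only; the embedder targets, config keys,
# and batch contents below are assumptions, and the CLIP embedder defaults to a
# CUDA device). It shows how each embedder's output dimensionality selects the
# conditioning key via OUTPUT_DIM2KEYS: 3-dim outputs land in "crossattn",
# 2-dim outputs in "vector".
def _demo_general_conditioner():
    conditioner = GeneralConditioner(
        emb_models=[
            {
                "target": "sgm.modules.encoders.modules.FrozenCLIPEmbedder",
                "is_trainable": False,
                "ucg_rate": 0.1,
                "input_key": "txt",
            },
            {
                "target": "sgm.modules.encoders.modules.ConcatTimestepEmbedderND",
                "input_key": "original_size_as_tuple",
                "params": {"outdim": 256},
            },
        ]
    ).cuda()
    batch = {
        "txt": ["a photo of a cat"],
        "original_size_as_tuple": torch.tensor([[1024, 1024]]),
    }
    c, uc = conditioner.get_unconditional_conditioning(
        batch, force_uc_zero_embeddings=["txt"]
    )
    # c["crossattn"]: (1, 77, 768) CLIP token states; c["vector"]: (1, 512)
    # concatenated size embeddings; uc["crossattn"] is zeroed out for "txt".
    return c, uc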
class InceptionV3(nn.Module):
"""Wrapper around the https://github.com/mseitzer/pytorch-fid inception
port with an additional squeeze at the end"""
def __init__(self, normalize_input=False, **kwargs):
super().__init__()
from pytorch_fid import inception
kwargs["resize_input"] = True
self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
def forward(self, inp):
# inp = kornia.geometry.resize(inp, (299, 299),
# interpolation='bicubic',
# align_corners=False,
# antialias=True)
# inp = inp.clamp(min=-1, max=1)
outp = self.model(inp)
if len(outp) == 1:
return outp[0].squeeze()
return outp
class IdentityEncoder(AbstractEmbModel):
def encode(self, x):
return x
def forward(self, x):
return x
class ClassEmbedder(AbstractEmbModel):
def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
super().__init__()
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.add_sequence_dim = add_sequence_dim
def forward(self, c):
c = self.embedding(c)
if self.add_sequence_dim:
c = c[:, None, :]
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.input_key: uc.long()}  # keyed by the input_key assigned via GeneralConditioner
return uc
class ClassEmbedderForMultiCond(ClassEmbedder):
    def forward(self, batch, key=None, disable_dropout=False):
        out = batch
        key = default(key, self.input_key)
        islist = isinstance(batch[key], list)
        if islist:
            batch[key] = batch[key][0]
        c_out = super().forward(batch[key])
        out[key] = [c_out] if islist else c_out
        return out
class FrozenT5Embedder(AbstractEmbModel):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):  # other sizes include google/t5-v1_1-base and google/t5-v1_1-xl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
# @autocast
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
with torch.autocast("cuda", enabled=False):
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenByT5Embedder(AbstractEmbModel):
"""
Uses the ByT5 transformer encoder for text. Is character-aware.
"""
def __init__(
self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):  # other sizes include google/byt5-small and google/byt5-large
super().__init__()
self.tokenizer = ByT5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
with torch.autocast("cuda", enabled=False):
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEmbModel):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None,
always_return_pooled=False,
): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
self.return_pooled = always_return_pooled
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
@autocast
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == "hidden"
)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.return_pooled:
return z, outputs.pooler_output
return z
def encode(self, text):
return self(text)
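# A minimal sketch (assumes a CUDA device and that the default
# "openai/clip-vit-large-patch14" weights can be downloaded): prompts are
# padded/truncated to max_length tokens and encoded to per-token hidden states.
def _demo_frozen_clip_embedder():
    embedder = FrozenCLIPEmbedder(layer="last").cuda()
    z = embedder.encode(["a photo of a cat", "a painting of a dog"])
    # z has shape (2, 77, 768): batch x max_length tokens x hidden size
    return z.shape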
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = ["pooled", "last", "penultimate"]
def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="last",
always_return_pooled=False,
legacy=True,
):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
pretrained=version,
)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
self.return_pooled = always_return_pooled
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
self.legacy = legacy
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
@autocast
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
if not self.return_pooled and self.legacy:
return z
if self.return_pooled:
assert not self.legacy
return z[self.layer], z["pooled"]
return z[self.layer]
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
if self.legacy:
x = x[self.layer]
x = self.model.ln_final(x)
return x
else:
# x is a dict and will stay a dict
o = x["last"]
o = self.model.ln_final(o)
pooled = self.pool(o, text)
x["pooled"] = pooled
return x
def pool(self, x, text):
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = (
x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
@ self.model.text_projection
)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
outputs = {}
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - 1:
outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
if (
self.model.transformer.grad_checkpointing
and not torch.jit.is_scripting()
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
return outputs
def encode(self, text):
return self(text)
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
LAYERS = [
# "pooled",
"last",
"penultimate",
]
def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="last",
):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device("cpu"), pretrained=version
)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if (
self.model.transformer.grad_checkpointing
and not torch.jit.is_scripting()
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(
self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
antialias=True,
ucg_rate=0.0,
unsqueeze_dim=False,
repeat_to_max_len=False,
num_image_crops=0,
output_tokens=False,
):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device("cpu"),
pretrained=version,
)
del model.transformer
self.model = model
self.max_crops = num_image_crops
self.pad_to_max_len = self.max_crops > 0
self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
self.ucg_rate = ucg_rate
self.unsqueeze_dim = unsqueeze_dim
self.stored_batch = None
self.model.visual.output_tokens = output_tokens
self.output_tokens = output_tokens
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
@autocast
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
tokens = None
if self.output_tokens:
z, tokens = z[0], z[1]
z = z.to(image.dtype)
if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
z = (
torch.bernoulli(
(1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
)[:, None]
* z
)
if tokens is not None:
tokens = (
expand_dims_like(
torch.bernoulli(
(1.0 - self.ucg_rate)
* torch.ones(tokens.shape[0], device=tokens.device)
),
tokens,
)
* tokens
)
if self.unsqueeze_dim:
z = z[:, None, :]
if self.output_tokens:
assert not self.repeat_to_max_len
assert not self.pad_to_max_len
return tokens, z
if self.repeat_to_max_len:
if z.dim() == 2:
z_ = z[:, None, :]
else:
z_ = z
return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
elif self.pad_to_max_len:
assert z.dim() == 3
z_pad = torch.cat(
(
z,
torch.zeros(
z.shape[0],
self.max_length - z.shape[1],
z.shape[2],
device=z.device,
),
),
1,
)
return z_pad, z_pad[:, 0, ...]
return z
def encode_with_vision_transformer(self, img):
# if self.max_crops > 0:
# img = self.preprocess_by_cropping(img)
if img.dim() == 5:
assert self.max_crops == img.shape[1]
img = rearrange(img, "b n c h w -> (b n) c h w")
img = self.preprocess(img)
if not self.output_tokens:
assert not self.model.visual.output_tokens
x = self.model.visual(img)
tokens = None
else:
assert self.model.visual.output_tokens
x, tokens = self.model.visual(img)
if self.max_crops > 0:
x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
# drop out between 0 and all along the sequence axis
x = (
torch.bernoulli(
(1.0 - self.ucg_rate)
* torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
)
* x
)
if tokens is not None:
tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
print(
f"You are running very experimental token-concat in {self.__class__.__name__}. "
f"Check what you are doing, and then remove this message."
)
if self.output_tokens:
return x, tokens
return x
def encode(self, text):
return self(text)
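# A minimal sketch (assumes a CUDA device and downloadable laion2b_s32b_b79k
# weights): images are expected in [-1, 1]; preprocess() resizes them to
# 224x224 and applies the CLIP normalization registered above.
def _demo_openclip_image_embedder():
    embedder = FrozenOpenCLIPImageEmbedder().cuda()
    imgs = torch.rand(2, 3, 256, 256, device="cuda") * 2.0 - 1.0
    z = embedder(imgs)
    # z has shape (2, 1024) for ViT-H-14: one pooled CLIP embedding per image
    return z.shape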
class FrozenCLIPT5Encoder(AbstractEmbModel):
def __init__(
self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77,
):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
class SpatialRescaler(nn.Module):
def __init__(
self,
n_stages=1,
method="bilinear",
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
wrap_video=False,
kernel_size=1,
remap_output=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in [
"nearest",
"linear",
"bilinear",
"trilinear",
"bicubic",
"area",
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None or remap_output
if self.remap_output:
print(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
bias=bias,
padding=kernel_size // 2,
)
self.wrap_video = wrap_video
def forward(self, x):
if self.wrap_video and x.ndim == 5:
B, C, T, H, W = x.shape
x = rearrange(x, "b c t h w -> b t c h w")
x = rearrange(x, "b t c h w -> (b t) c h w")
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.wrap_video:
x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
x = rearrange(x, "b t c h w -> b c t h w")
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
class LowScaleEncoder(nn.Module):
def __init__(
self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0,
):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(
self,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def forward(self, x):
z = self.model.encode(x)
if isinstance(z, DiagonalGaussianDistribution):
z = z.sample()
z = z * self.scale_factor
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0],), device=x.device
).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)
class ConcatTimestepEmbedderND(AbstractEmbModel):
"""embeds each dimension independently and concatenates them"""
def __init__(self, outdim):
super().__init__()
self.timestep = Timestep(outdim)
self.outdim = outdim
def forward(self, x):
if x.ndim == 1:
x = x[:, None]
assert len(x.shape) == 2
b, dims = x.shape[0], x.shape[1]
x = rearrange(x, "b d -> (b d)")
emb = self.timestep(x)
emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return emb
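# A minimal sketch of the SDXL-style size/crop conditioning this class is used
# for (the concrete values are illustrative): each scalar gets its own
# sinusoidal Timestep embedding of length outdim, concatenated per sample.
def _demo_concat_timestep_embedder():
    embedder = ConcatTimestepEmbedderND(outdim=256)
    sizes = torch.tensor([[1024, 1024], [768, 512]])  # (batch=2, dims=2)
    emb = embedder(sizes)
    # emb has shape (2, 512): 2 dims x 256 features each, concatenated per sample
    return emb.shape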
class GaussianEncoder(Encoder, AbstractEmbModel):
def __init__(
self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.posterior = DiagonalGaussianRegularizer()
self.weight = weight
self.flatten_output = flatten_output
def forward(self, x) -> Tuple[Dict, torch.Tensor]:
z = super().forward(x)
z, log = self.posterior(z)
log["loss"] = log["kl_loss"]
log["weight"] = self.weight
if self.flatten_output:
            z = rearrange(z, "b c h w -> b (h w) c")
return log, z

231
sgm/util.py Normal file
View File

@@ -0,0 +1,231 @@
import functools
import importlib
import os
from functools import partial
from inspect import isfunction
import fsspec
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def get_string_from_tuple(s):
try:
# Check if the string starts and ends with parentheses
if s[0] == "(" and s[-1] == ")":
# Convert the string to a tuple
t = eval(s)
# Check if the type of t is tuple
if type(t) == tuple:
return t[0]
else:
pass
except:
pass
return s
def is_power_of_two(n):
"""
chat.openai.com/chat
Return True if n is a power of 2, otherwise return False.
The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
"""
if n <= 0:
return False
return (n & (n - 1)) == 0
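# Worked example of the bitwise check above:
#   n = 8 -> 0b1000, n - 1 = 7 -> 0b0111, 8 & 7 == 0            -> True
#   n = 6 -> 0b0110, n - 1 = 5 -> 0b0101, 6 & 5 == 0b0100 != 0  -> False
#   n = 0 is rejected by the n <= 0 guard                        -> False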
def autocast(f, enabled=True):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=enabled,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled(),
):
return f(*args, **kwargs)
return do_autocast
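# A minimal sketch of the decorator above (assumes a CUDA device; the wrapped
# function is illustrative): the call runs under torch.cuda.amp.autocast with
# the globally configured autocast dtype.
def _demo_autocast():
    @autocast
    def project(a, b):
        return a @ b
    a = torch.randn(4, 8, device="cuda")
    b = torch.randn(8, 2, device="cuda")
    return project(a, b).dtype  # typically torch.float16 under default autocast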
def load_partial_from_config(config):
return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
nc = int(40 * (wh[0] / 256))
if isinstance(xc[bi], list):
text_seq = xc[bi][0]
else:
text_seq = xc[bi]
lines = "\n".join(
text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == "file":
return os.path.abspath(p)
return path
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def isheatmap(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 2
def isneighbors(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
def exists(x):
return x is not None
def expand_dims_like(x, y):
while x.dim() != y.dim():
x = x.unsqueeze(-1)
return x
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit(".", 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
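# A minimal sketch: broadcasting per-sample scalars (e.g. noise levels) against
# an image batch by appending trailing singleton dimensions.
def _demo_append_dims():
    sigmas = torch.tensor([0.5, 1.0])       # shape (2,)
    images = torch.randn(2, 3, 64, 64)      # shape (2, 3, 64, 64)
    scaled = images * append_dims(sigmas, images.ndim)  # sigmas -> (2, 1, 1, 1)
    return scaled.shape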
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if freeze:
for param in model.parameters():
param.requires_grad = False
model.eval()
return model