Mirror of https://github.com/Stability-AI/generative-models.git (synced 2026-02-05 05:44:29 +01:00)
Compare commits
42 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 45c443b316 |  |
|  | dea60596fc |  |
|  | 299abbcd90 |  |
|  | e5d714d304 |  |
|  | f2fa96b7e5 |  |
|  | c60c091f4d |  |
|  | 931d7a389a |  |
|  | e596332148 |  |
|  | 4a3f0f546e |  |
|  | 7934245835 |  |
|  | 1da250906d |  |
|  | a4ceca6d03 |  |
|  | 68f3f89bd3 |  |
|  | 57862fb4c7 |  |
|  | ef520df1db |  |
|  | 2897fdc99a |  |
|  | b5b5680150 |  |
|  | 6f6d3f8716 |  |
|  | 6ecd0a900a |  |
|  | e25e4c0df1 |  |
|  | e5dc9669ed |  |
|  | 48904a692d |  |
|  | 5c10deee76 |  |
|  | 89f5413e6d |  |
|  | ba3e7fed5a |  |
|  | ea89ce793d |  |
|  | 7b1978e055 |  |
|  | 9d5ace911e |  |
|  | 95b9acc5c6 |  |
|  | 5df4d9893c |  |
|  | ae18ba3e87 |  |
|  | 061d11d55d |  |
|  | 2796c81a5f |  |
|  | e9869d7822 |  |
|  | 613af104c6 |  |
|  | 376cec3b0f |  |
|  | 76e549dd94 |  |
|  | 5f0a2fcf48 |  |
|  | d8a6a97fb0 |  |
|  | a1af4ac4f1 |  |
|  | 58ddbee3ee |  |
|  | bec98beff8 |  |
.github/workflows/black.yml (vendored, new file, 15 lines)
@@ -0,0 +1,15 @@
name: Run black
on: [pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Install venv
        run: |
          sudo apt-get -y install python3.10-venv
      - uses: psf/black@stable
        with:
          options: "--check --verbose -l88"
          src: "./sgm ./scripts ./main.py"
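The same check can be reproduced locally with `black` itself; a sketch using the options and paths from the workflow above (the version pin matches the requirements files further down):

```shell
pip install black==23.7.0
black --check --verbose -l88 ./sgm ./scripts ./main.py
```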
.github/workflows/test-build.yaml (vendored, new file, 27 lines)
@@ -0,0 +1,27 @@
name: Build package

on:
  push:
    branches: [ main ]
  pull_request:

jobs:
  build:
    name: Build
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.8", "3.10"]
        requirements-file: ["pt2", "pt13"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements/${{ matrix.requirements-file }}.txt
          pip install .
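One cell of the build matrix can be reproduced locally along these lines (a sketch, assuming Python 3.10 and the pt2 requirements set):

```shell
python3.10 -m venv .pt2
source .pt2/bin/activate
python -m pip install --upgrade pip
pip install -r requirements/pt2.txt
pip install .
```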
.github/workflows/test-inference.yml (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
name: Test inference

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  test:
    name: "Test inference"
    # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
    if: github.repository == 'stability-ai/generative-models'
    runs-on: [self-hosted, slurm, g40]
    steps:
      - uses: actions/checkout@v3
      - name: "Symlink checkpoints"
        run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
      - name: "Setup python"
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - name: "Install Hatch"
        run: pip install hatch
      - name: "Run inference tests"
        run: hatch run ci:test-inference --junit-xml test-results.xml
      - name: Surface failing tests
        if: always()
        uses: pmeier/pytest-results-action@main
        with:
          path: test-results.xml
          summary: true
          display-options: fEX
          fail-on-empty: true
.gitignore (vendored, 17 changed lines)
@@ -1,7 +1,14 @@
.pt2
.pt2_2
.pt13
# extensions
*.egg-info
build
*.py[cod]

# envs
.pt13
.pt2

# directories
/checkpoints
/dist
/outputs
/checkpoints
/build
/src
CODEOWNERS (new file, 1 line)
@@ -0,0 +1 @@
.github @Stability-AI/infrastructure
LICENSE-CODE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Stability AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (120 changed lines)
@@ -4,14 +4,29 @@

## News

**July 26, 2023**
- We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
  - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
  - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.



**July 4, 2023**
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).

**June 22, 2023**


- We are releasing two new diffusion models:
  - `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
  - `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
- We are releasing two new diffusion models for research purposes:
  - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
  - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.

**We plan to do a full release soon (July).**
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.
**We plan to do a full release soon (July).**

## The codebase

@@ -21,7 +36,7 @@ Modularity is king. This repo implements a config-driven approach where we build

### Changelog from the old `ldm` codebase

For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:

- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
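To make the guider/sampler separation above concrete: a guider is essentially a small callable that combines the conditional and unconditional denoiser outputs, independently of the sampler that calls it. A minimal sketch of classifier-free guidance in that style (hypothetical class, not the repository's actual guider API):

```python
import torch


class CFGSketch:
    """Hypothetical guider: blends conditional and unconditional denoiser outputs."""

    def __init__(self, scale: float = 5.0):
        self.scale = scale

    def __call__(self, x_cond: torch.Tensor, x_uncond: torch.Tensor) -> torch.Tensor:
        # classifier-free guidance: push the prediction away from the
        # unconditional branch and towards the conditional one
        return x_uncond + self.scale * (x_cond - x_uncond)
```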
@@ -43,47 +58,98 @@ cd generative-models

#### 2. Setting up the virtualenv

This is assuming you have navigated to the `generative-models` root after cloning it.

**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.


**PyTorch 1.13**

```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```

**PyTorch 2.0**


```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
pip3 install -r requirements/pt2.txt
```

## Inference:

We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
#### 3. Install `sgm`

```shell
pip3 install .
```

#### 4. Install `sdata` for training

```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```

## Packaging

This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).

To build a distributable wheel, install `hatch` and run `hatch build`
(specifying `-t wheel` will skip building a sdist, which is not necessary).

```
pip install hatch
hatch build -t wheel
```

You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.

Note that the package does **not** currently specify dependencies; you will need to install the required packages,
depending on your use case and PyTorch version, manually.

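Since the wheel does not declare dependencies, installing it is typically paired with one of the pinned requirements files; a sketch:

```shell
pip install -r requirements/pt2.txt  # or requirements/pt13.txt
pip install dist/*.whl
```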
## Inference

We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
The following models are currently supported:

- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
  ```
  File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
  Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
  ```
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
  ```
  File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
  Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
  ```
- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
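The full-file hashes above can be checked with a standard tool once a checkpoint has been downloaded (a sketch; verifying the tensor-data hash requires the ModelSpec script linked above):

```shell
sha256sum checkpoints/sd_xl_base_1.0.safetensors
```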

**Weights for SDXL**:
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your HuggingFace Account with your organization email to request access.

After obtaining the weights, place them into `checkpoints/`.
**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/


**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.


After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using

```
@@ -137,7 +203,7 @@ run
python main.py --base configs/example_training/toy/mnist_cond.yaml
```

**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depdending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to be stored in tar-files in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.

**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for autoencoder training, as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.

@@ -174,7 +240,7 @@ guidance.
### Dataset Handling


For large scale training we recommend using the datapipelines from our [datapipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values,
e.g.,

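A minimal sketch of such a map-style dataset (hypothetical names and keys, for illustration only; the repository's actual example is not captured in this excerpt):

```python
from torch.utils.data import Dataset


class ToyDictDataset(Dataset):
    """Hypothetical map-style dataset returning a dict of data keys/values."""

    def __init__(self, images, captions):
        self.images = images      # e.g. a list of image tensors
        self.captions = captions  # e.g. a list of caption strings

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {"jpg": self.images[idx], "txt": self.captions[idx]}
```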
BIN assets/001_with_eval.png (new binary file, not shown; size: 4.0 MiB)
main.py (14 changed lines)
@@ -12,22 +12,18 @@ import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import (
    exists,
    instantiate_from_config,
    isheatmap,
)
from sgm.util import exists, instantiate_from_config, isheatmap

MULTINODE_HACKS = True

@@ -469,9 +465,8 @@ class ImageLogger(Callback):
        self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, **kwargs
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
@@ -911,11 +906,12 @@ if __name__ == "__main__":
            trainer.test(model, data)
        except RuntimeError as err:
            if MULTINODE_HACKS:
                import requests
                import datetime
                import os
                import socket

                import requests

                device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
                hostname = socket.gethostname()
                ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")

model_licenses/LICENSE-SDXL1.0 (new file, 175 lines)
@@ -0,0 +1,175 @@
Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023

Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and
have the potential to transform the way artists, among other individuals, conceive and
benefit from AI or ML technologies as a tool for content creation. Notwithstanding the
current and potential benefits that these artifacts can bring to society at large, there
are also concerns about potential misuses of them, either due to their technical
limitations or ethical considerations. In short, this license strives for both the open
and responsible downstream use of the accompanying model. When it comes to the open
character, we took inspiration from open source permissive licenses regarding the grant
of IP rights. Referring to the downstream responsible use, we added use-based
restrictions not permitting the use of the model in very specific scenarios, in order
for the licensor to be able to enforce the license in case potential misuses of the
Model may occur. At the same time, we strive to promote open and responsible research on
generative models for art and content generation. Even though downstream derivative
versions of the model could be released under different licensing terms, the latter will
always have to include - at minimum - the same use-based restrictions as the ones in the
original license (this license). We believe in the intersection between open and
responsible AI development; thus, this agreement aims to strike a balance between both
in order to enable responsible open-science in the field of AI. This CreativeML Open
RAIL++-M License governs the use of the model (and its derivatives) and is informed by
the model card associated with the model. NOW THEREFORE, You and Licensor agree as
follows: Definitions "License" means the terms and conditions for use, reproduction, and
Distribution as defined in this document. "Data" means a collection of information
and/or content extracted from the dataset used with the Model, including to train,
pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content
resulting therefrom. "Model" means any accompanying machine-learning based assemblies
(including checkpoints), consisting of learnt weights, parameters (including optimizer
states), corresponding to the model architecture as embodied in the Complementary
Material, that have been trained or tuned, in whole or in part on the Data, using the
Complementary Material. "Derivatives of the Model" means all modifications to the Model,
works based on the Model, or any other model which is created or initialized by transfer
of patterns of the weights, parameters, activations or output of the Model, to the other
model, in order to cause the other model to perform similarly to the Model, including -
but not limited to - distillation methods entailing the use of intermediate data
representations or methods based on the generation of synthetic data by the Model for
training the other model. "Complementary Material" means the accompanying source code
and scripts used to define, run, load, benchmark or evaluate the Model, and used to
prepare data for training or evaluation, if any. This includes any accompanying
documentation, tutorials, examples, etc, if any. "Distribution" means any transmission,
reproduction, publication or other sharing of the Model or Derivatives of the Model to a
third party, including providing the Model as a hosted service made available by
electronic or other remote means - e.g. API-based or web access. "Licensor" means the
copyright owner or entity authorized by the copyright owner that is granting the
License, including the persons or entities that may have rights in the Model and/or
distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising
permissions granted by this License and/or making use of the Model for whichever purpose
and in any field of use, including usage of the Model in an end-use application - e.g.
chatbot, translator, image generator. "Third Parties" means individuals or legal
entities that are not under common control with Licensor or You. "Contribution" means
any work of authorship, including the original version of the Model and any
modifications or additions to that Model or Derivatives of the Model thereof, that is
intentionally submitted to Licensor for inclusion in the Model by the copyright owner or
by an individual or Legal Entity authorized to submit on behalf of the copyright owner.
For the purposes of this definition, "submitted" means any form of electronic, verbal,
or written communication sent to the Licensor or its representatives, including but not
limited to communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for the
purpose of discussing and improving the Model, but excluding communication that is
conspicuously marked or otherwise designated in writing by the copyright owner as "Not a
Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently incorporated
within the Model.

Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the
Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of
the Model are subject to additional terms as described in

Section III. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly
display, publicly perform, sublicense, and distribute the Complementary Material, the
Model, and Derivatives of the Model. Grant of Patent License. Subject to the terms and
conditions of this License and where and as applicable, each Contributor hereby grants
to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this paragraph) patent license to make, have made, use, offer to
sell, sell, import, and otherwise transfer the Model and the Complementary Material,
where such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination of their
Contribution(s) with the Model to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a cross-claim or counterclaim
in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
incorporated within the Model and/or Complementary Material constitutes direct or
contributory patent infringement, then any patent licenses granted to You under this
License for the Model and/or Work shall terminate as of the date such litigation is
asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
Distribution and Redistribution. You may host for Third Party remote access purposes
(e.g. software-as-a-service), reproduce and distribute copies of the Model or
Derivatives of the Model thereof in any medium, with or without modifications, provided
that You meet the following conditions: Use-based restrictions as referenced in
paragraph 5 MUST be included as an enforceable provision by You in any type of legal
agreement (e.g. a license) governing the use and/or distribution of the Model or
Derivatives of the Model, and You shall give notice to subsequent users You Distribute
to, that the Model or Derivatives of the Model are subject to paragraph 5. This
provision does not apply to the use of Complementary Material. You must give any Third
Party recipients of the Model or Derivatives of the Model a copy of this License; You
must cause any modified files to carry prominent notices stating that You changed the
files; You must retain all copyright, patent, trademark, and attribution notices
excluding those notices that do not pertain to any part of the Model, Derivatives of the
Model. You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions - respecting paragraph 4.a. - for
use, reproduction, or Distribution of Your modifications, or for any such Derivatives of
the Model as a whole, provided Your use, reproduction, and Distribution of the Model
otherwise complies with the conditions stated in this License. Use-based restrictions.
The restrictions set forth in Attachment A are considered Use-based restrictions.
Therefore You cannot use the Model and the Derivatives of the Model for the specified
restricted uses. You may use the Model subject to this License, including only for
lawful purposes and in accordance with the License. Use may include creating any content
with, finetuning, updating, running, training, evaluating and/or reparametrizing the
Model. You shall require all of Your users who use the Model or a Derivative of the
Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate.
Except as set forth herein, Licensor claims no rights in the Output You generate using
the Model. You are accountable for the Output you generate and its subsequent uses. No
use of the output can contravene any provision as stated in the License.

Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent
permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage
of the Model in violation of this License. Trademarks and related. Nothing in this
License permits You to make use of Licensors’ trademarks, trade names, logos or to
otherwise suggest endorsement or misrepresent the relationship between the parties; and
any rights not expressly granted herein are reserved by the Licensors. Disclaimer of
Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
the Model and the Complementary Material (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied, including, without limitation, any warranties or conditions of
TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or redistributing the
Model, Derivatives of the Model, and the Complementary Material and assume any risks
associated with Your exercise of permissions under this License. Limitation of
Liability. In no event and under no legal theory, whether in tort (including
negligence), contract, or otherwise, unless required by applicable law (such as
deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental, or
consequential damages of any character arising as a result of this License or out of the
use or inability to use the Model and the Complementary Material (including but not
limited to damages for loss of goodwill, work stoppage, computer failure or malfunction,
or any and all other commercial damages or losses), even if such Contributor has been
advised of the possibility of such damages. Accepting Warranty or Additional Liability.
While redistributing the Model, Derivatives of the Model and the Complementary Material
thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty,
indemnity, or other liability obligations and/or rights consistent with this License.
However, in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability incurred by, or
claims asserted against, such Contributor by reason of your accepting any such warranty
or additional liability. If any provision of this License is held to be invalid, illegal
or unenforceable, the remaining provisions shall be unaffected thereby and remain valid
as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS

Attachment A Use Restrictions
You agree not to use the Model or Derivatives of the Model:
In any way that violates any applicable national, federal, state, local or
international law or regulation; For the purpose of exploiting, harming or attempting to
exploit or harm minors in any way; To generate or disseminate verifiably false
information and/or content with the purpose of harming others; To generate or
disseminate personal identifiable information that can be used to harm an individual; To
defame, disparage or otherwise harass others; For fully automated decision making that
adversely impacts an individual’s legal rights or otherwise creates or modifies a
binding, enforceable obligation; For any use intended to or which has the effect of
discriminating against or harming individuals or groups based on online or offline
social behavior or known or predicted personal or personality characteristics; To
exploit any of the vulnerabilities of a specific group of persons based on their age,
social, physical or mental characteristics, in order to materially distort the behavior
of a person pertaining to that group in a manner that causes or is likely to cause that
person or another person physical or psychological harm; For any use intended to or
which has the effect of discriminating against individuals or groups based on legally
protected characteristics or categories; To provide medical advice and medical results
interpretation; To generate or disseminate information for the purpose to be used for
administration of justice, law enforcement, immigration or asylum processes, such as
predicting an individual will commit fraud/crime commitment (e.g. by text profiling,
drawing causal relationships between assertions made in documents, indiscriminate and
arbitrarily-targeted use).
pyproject.toml (new file, 48 lines)
@@ -0,0 +1,48 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "sgm"
dynamic = ["version"]
description = "Stability Generative Models"
readme = "README.md"
license-files = { paths = ["LICENSE-CODE"] }
requires-python = ">=3.8"

[project.urls]
Homepage = "https://github.com/Stability-AI/generative-models"

[tool.hatch.version]
path = "sgm/__init__.py"

[tool.hatch.build]
# This needs to be explicitly set so the configuration files
# grafted into the `sgm` directory get included in the wheel's
# RECORD file.
include = [
    "sgm",
]
# The force-include configurations below make Hatch copy
# the configs/ directory (containing the various YAML files required
# to generatively model) into the source distribution and the wheel.

[tool.hatch.build.targets.sdist.force-include]
"./configs" = "sgm/configs"

[tool.hatch.build.targets.wheel.force-include]
"./configs" = "sgm/configs"

[tool.hatch.envs.ci]
skip-install = false

dependencies = [
    "pytest"
]

[tool.hatch.envs.ci.scripts]
test-inference = [
    "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
    "pip install -r requirements/pt2.txt",
    "pytest -v tests/inference/test_inference.py {args}",
]
pytest.ini (new file, 3 lines)
@@ -0,0 +1,3 @@
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
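The `ci` environment and the `inference` marker above are what the test-inference workflow invokes; locally the same entry point can be run roughly as follows (a sketch; it assumes a CUDA-capable machine and model weights already placed in `checkpoints/`):

```shell
pip install hatch
hatch run ci:test-inference                      # installs the pinned cu118 stack, then runs pytest -v tests/inference/test_inference.py
hatch run ci:test-inference -m "not inference"   # extra arguments are forwarded to pytest via {args}
```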
requirements/pt13.txt (new file, 40 lines)
@@ -0,0 +1,40 @@
black==23.7.0
chardet>=5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
numpy>=1.24.4
omegaconf>=2.3.0
onnx<=1.12.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.8.5
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=1.25.0
tensorboardx==2.5.1
timm>=0.9.2
tokenizers==0.12.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
torchaudio==0.13.1
torchdata==0.5.1
torchmetrics>=1.0.1
torchvision==0.14.1+cu117
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0.post1
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.16
requirements/pt2.txt (new file, 39 lines)
@@ -0,0 +1,39 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20
Deleted file (41 lines):
@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
xformers==0.0.16
torchaudio==0.13.1
torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
matplotlib
natsort
tensorboardx==2.5.1
open-clip-torch
chardet
scipy
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
streamlit>=0.73.1
timm
tokenizers==0.12.1
torchdata==0.5.1
transformers==4.19.1
onnx<=1.12.0
triton
wandb
invisible-watermark
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .
Deleted file (41 lines):
@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
torch
matplotlib
torchaudio>=2.0.2
torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fire
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.0.0
torchdata==0.6.1
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .
scripts/__init__.py (new empty file)
scripts/demo/__init__.py (new empty file)
@@ -83,7 +83,7 @@ class GetWatermarkMatch:
    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Detects the number of matching bits of the predefined watermark with one
        or multiple images. Images should be in cv2 format, e.g. h x w x c.
        or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.

        Args:
            x: ([B], h, w, c) in range [0, 255]
@@ -94,7 +94,6 @@ class GetWatermarkMatch:
        squeeze = len(x.shape) == 3
        if squeeze:
            x = x[None, ...]
        x = np.flip(x, axis=-1)

        bs = x.shape[0]
        detected = np.empty((bs, self.num_bits), dtype=bool)

@@ -1,6 +1,6 @@
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering

SAVE_PATH = "outputs/demo/txt2img/"

@@ -34,7 +34,16 @@ SD_XL_BASE_RATIOS = {
}

VERSION2SPECS = {
    "SD-XL base": {
    "SDXL-base-1.0": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
    },
    "SDXL-base-0.9": {
        "H": 1024,
        "W": 1024,
        "C": 4,
@@ -42,9 +51,8 @@ VERSION2SPECS = {
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
        "is_guided": True,
    },
    "sd-2.1": {
    "SD-2.1": {
        "H": 512,
        "W": 512,
        "C": 4,
@@ -52,9 +60,8 @@ VERSION2SPECS = {
        "is_legacy": True,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
        "is_guided": True,
    },
    "sd-2.1-768": {
    "SD-2.1-768": {
        "H": 768,
        "W": 768,
        "C": 4,
@@ -63,7 +70,7 @@ VERSION2SPECS = {
        "config": "configs/inference/sd_2_1_768.yaml",
        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
    },
    "SDXL-Refiner": {
    "SDXL-refiner-0.9": {
        "H": 1024,
        "W": 1024,
        "C": 4,
@@ -71,7 +78,15 @@ VERSION2SPECS = {
        "is_legacy": True,
        "config": "configs/inference/sd_xl_refiner.yaml",
        "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
        "is_guided": True,
    },
    "SDXL-refiner-1.0": {
        "H": 1024,
        "W": 1024,
        "C": 4,
        "f": 8,
        "is_legacy": True,
        "config": "configs/inference/sd_xl_refiner.yaml",
        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
    },
}

@@ -95,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):


def run_txt2img(
    state, version, version_dict, is_legacy=False, return_latents=False, filter=None
    state,
    version,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
):
    if version == "SD-XL base":
        ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
        W, H = SD_XL_BASE_RATIOS[ratio]
    if version.startswith("SDXL-base"):
        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
    else:
        H = st.sidebar.number_input(
            "H", value=version_dict["H"], min_value=64, max_value=2048
        )
        W = st.sidebar.number_input(
            "W", value=version_dict["W"], min_value=64, max_value=2048
        )
        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
    C = version_dict["C"]
    F = version_dict["f"]

@@ -122,10 +138,7 @@
        prompt=prompt,
        negative_prompt=negative_prompt,
    )
    num_rows, num_cols, sampler = init_sampling(
        use_identity_guider=not version_dict["is_guided"]
    )

    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
    num_samples = num_rows * num_cols

    if st.button("Sample"):
@@ -147,7 +160,12 @@


def run_img2img(
    state, version_dict, is_legacy=False, return_latents=False, filter=None
    state,
    version_dict,
    is_legacy=False,
    return_latents=False,
    filter=None,
    stage2strength=None,
):
    img = load_img()
    if img is None:
@@ -163,13 +181,15 @@
    value_dict = init_embedder_options(
        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
        init_dict,
        prompt=prompt,
        negative_prompt=negative_prompt,
    )
    strength = st.number_input(
        "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
    )
    num_rows, num_cols, sampler = init_sampling(
    sampler, num_rows, num_cols = init_sampling(
        img2img_strength=strength,
        use_identity_guider=not version_dict["is_guided"],
        stage2strength=stage2strength,
    )
    num_samples = num_rows * num_cols

@@ -195,6 +215,7 @@ def apply_refiner(
    prompt,
    negative_prompt,
    filter=None,
    finish_denoising=False,
):
    init_dict = {
        "orig_width": input.shape[3] * 8,
@@ -222,6 +243,7 @@
        num_samples,
        skip_encode=True,
        filter=filter,
        add_noise=not finish_denoising,
    )

    return samples
@@ -234,20 +256,20 @@ if __name__ == "__main__":
    mode = st.radio("Mode", ("txt2img", "img2img"), 0)
    st.write("__________________________")

    if version == "SD-XL base":
        add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
    set_lowvram_mode(st.checkbox("Low vram mode", True))

    if version.startswith("SDXL-base"):
        add_pipeline = st.checkbox("Load SDXL-refiner?", False)
        st.write("__________________________")
    else:
        add_pipeline = False

    filter = DeepFloydDataFiltering(verbose=False)

    seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
    seed_everything(seed)

    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

    state = init_st(version_dict)
    state = init_st(version_dict, load_filter=True)
    if state["msg"]:
        st.info(state["msg"])
    model = state["model"]
@@ -263,30 +285,34 @@
    else:
        negative_prompt = ""  # which is unused

    stage2strength = None
    finish_denoising = False

    if add_pipeline:
        st.write("__________________________")

        version2 = "SDXL-Refiner"
        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
        st.warning(
            f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
        )
        st.write("**Refiner Options:**")

        version_dict2 = VERSION2SPECS[version2]
        state2 = init_st(version_dict2)
        state2 = init_st(version_dict2, load_filter=False)
        st.info(state2["msg"])

        stage2strength = st.number_input(
            "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
            "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
        )

        sampler2 = init_sampling(
        sampler2, *_ = init_sampling(
            key=2,
            img2img_strength=stage2strength,
            use_identity_guider=not version_dict["is_guided"],
            get_num_samples=False,
            specify_num_samples=False,
        )
        st.write("__________________________")
        finish_denoising = st.checkbox("Finish denoising with refiner.", True)
        if not finish_denoising:
            stage2strength = None

    if mode == "txt2img":
        out = run_txt2img(
@@ -295,7 +321,8 @@
            version_dict,
            is_legacy=is_legacy,
            return_latents=add_pipeline,
            filter=filter,
            filter=state.get("filter"),
            stage2strength=stage2strength,
        )
    elif mode == "img2img":
        out = run_img2img(
@@ -303,7 +330,8 @@
            version_dict,
            is_legacy=is_legacy,
            return_latents=add_pipeline,
            filter=filter,
            filter=state.get("filter"),
            stage2strength=stage2strength,
        )
    else:
        raise ValueError(f"unknown mode {mode}")
@@ -311,8 +339,9 @@
        samples, samples_z = out
    else:
        samples = out
        samples_z = None

    if add_pipeline:
    if add_pipeline and samples_z is not None:
        st.write("**Running Refinement Stage**")
        samples = apply_refiner(
            samples_z,
@@ -321,7 +350,8 @@
            samples_z.shape[0],
            prompt=prompt,
            negative_prompt=negative_prompt if is_legacy else "",
            filter=filter,
            filter=state.get("filter"),
            finish_denoising=finish_denoising,
        )

        if save_locally and samples is not None:

@@ -1,29 +1,29 @@
import os
from typing import Union, List

import math
import os
from typing import List, Union

import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors

from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
    DPMPP2MSampler,
    DPMPP2SAncestralSampler,
    EulerAncestralSampler,
    EulerEDMSampler,
    HeunEDMSampler,
    EulerAncestralSampler,
    DPMPP2SAncestralSampler,
    DPMPP2MSampler,
    LinearMultistepSampler,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
from sgm.util import append_dims, instantiate_from_config


class WatermarkEmbedder:
@@ -43,19 +43,19 @@ class WatermarkEmbedder:
        Returns:
            same as input but watermarked
        """
        # watermarking libary expects input as cv2 format
        # watermarking libary expects input as cv2 BGR format
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()
        ).numpy()[:, :, :, ::-1]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(
            rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
@@ -72,7 +72,7 @@ embed_watemark = WatermarkEmbedder(WATERMARK_BITS)


@st.cache_resource()
def init_st(version_dict, load_ckpt=True):
def init_st(version_dict, load_ckpt=True, load_filter=True):
    state = dict()
    if not "model" in state:
        config = version_dict["config"]
@@ -85,9 +85,39 @@ def init_st(version_dict, load_ckpt=True):
        state["model"] = model
        state["ckpt"] = ckpt if load_ckpt else None
        state["config"] = config
        if load_filter:
            state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


def load_model(model):
    model.cuda()


lowvram_mode = False


def set_lowvram_mode(mode):
    global lowvram_mode
    lowvram_mode = mode


def initial_model_load(model):
    global lowvram_mode
    if lowvram_mode:
        model.model.half()
    else:
        model.cuda()
    return model


def unload_model(model):
    global lowvram_mode
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()


def load_model_from_config(config, ckpt=None, verbose=True):
    model = instantiate_from_config(config.model)

@@ -118,7 +148,7 @@ def load_model_from_config(config, ckpt=None, verbose=True):
    else:
        msg = None

    model.cuda()
    model = initial_model_load(model)
    model.eval()
    return model, msg

@@ -170,19 +200,8 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            target_width = st.number_input(
                "target_width",
                value=init_dict["target_width"],
                min_value=16,
            )
            target_height = st.number_input(
                "target_height",
                value=init_dict["target_height"],
                min_value=16,
            )

            value_dict["target_width"] = target_width
            value_dict["target_height"] = target_height
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict

@@ -233,6 +252,36 @@ class Img2ImgDiscretizationWrapper:
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print(f"sigmas after pruning: ", sigmas)
        return sigmas

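# Worked example of the pruning arithmetic above (illustration only, not part of the
# repository code): with the demo defaults used further down (40 sampler steps and a
# refinement strength of 0.15), original_steps=40 gives steps = 41 and
#   prune_index = max(min(int(0.15 * 41) - 1, 41 - 1), 0) = 5
# so the 5 lowest-noise sigmas are dropped from the first-stage schedule and the
# remaining noise is left for the refiner to remove.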

def get_guider(key):
    guider = st.sidebar.selectbox(
        f"Discretization #{key}",
@@ -275,16 +324,19 @@


def init_sampling(
    key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
):
    if get_num_samples:
        num_rows = 1
    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=50, min_value=1, max_value=1000
        f"steps #{key}", value=40, min_value=1, max_value=1000
    )
    sampler = st.sidebar.selectbox(
        f"Sampler #{key}",
@@ -318,17 +370,17 @@ def init_sampling(
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization, strength=img2img_strength
|
||||
)
|
||||
if get_num_samples:
|
||||
return num_rows, num_cols, sampler
|
||||
return sampler
|
||||
if stage2strength is not None:
|
||||
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
||||
sampler.discretization, strength=stage2strength, original_steps=steps
|
||||
)
|
||||
return sampler, num_rows, num_cols
|
||||
|
||||
|
||||
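The stage2strength hook lets a first-stage sampler stop before full denoising so a refiner can pick up from the remaining noise level. A rough, illustrative wiring (not taken verbatim from this diff; the value 0.15 is an assumption):

# base stage: prune the tail of the schedule
base_sampler, num_rows, num_cols = init_sampling(stage2strength=0.15)
# refiner stage: start img2img from the same noise level
refiner_sampler, *_ = init_sampling(key=2, img2img_strength=0.15, specify_num_samples=False)
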
def get_discretization(discretization, key=1):
    if discretization == "LegacyDDPMDiscretization":
        use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
            "params": {"legacy_range": not use_new_range},
        }
    elif discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
@@ -484,6 +536,7 @@ def do_sample(
    with precision_scope("cuda"):
        with model.ema_scope():
            num_samples = [num_samples]
            load_model(model.conditioner)
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
@@ -501,6 +554,7 @@ def do_sample(
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )
            unload_model(model.conditioner)

            for k in c:
                if not k == "crossattn":
@@ -520,9 +574,16 @@ def do_sample(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

            load_model(model.denoiser)
            load_model(model.model)
            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            unload_model(model.model)
            unload_model(model.denoiser)

            load_model(model.first_stage_model)
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            unload_model(model.first_stage_model)

            if filter is not None:
                samples = filter(samples)
@@ -606,6 +667,7 @@ def do_img2img(
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    st.text("Sampling")

@@ -614,6 +676,7 @@ def do_img2img(
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
@@ -624,7 +687,7 @@ def do_img2img(
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                unload_model(model.conditioner)
                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))

@@ -633,28 +696,41 @@ def do_img2img(
                if skip_encode:
                    z = img
                else:
                    load_model(model.first_stage_model)
                    z = model.encode_first_stage(img)
                    unload_model(model.first_stage_model)

                noise = torch.randn_like(z)
                sigmas = sampler.discretization(sampler.num_steps)

                sigmas = sampler.discretization(sampler.num_steps).cuda()
                sigma = sigmas[0]

                st.info(f"all sigmas: {sigmas}")
                st.info(f"noising sigma: {sigma}")

                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                noised_z = z + noise * append_dims(sigma, z.ndim)
                noised_z = noised_z / torch.sqrt(
                    1.0 + sigmas[0] ** 2.0
                )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(
                        1.0 + sigmas[0] ** 2.0
                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                unload_model(model.first_stage_model)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:

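The add_noise branch above is the DDPM-style latent noising used here: the encoded latent is perturbed with noise scaled by the first sigma of the schedule and then renormalized. A standalone sketch (shapes and values are illustrative):

# noised = (z + sigma_0 * eps) / sqrt(1 + sigma_0**2), with eps ~ N(0, I)
z = torch.randn(1, 4, 128, 128)    # stand-in latent
eps = torch.randn_like(z)
sigma_0 = torch.tensor(14.6146)    # highest noise level of the schedule
noised = (z + sigma_0 * eps) / torch.sqrt(1.0 + sigma_0**2)
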
319
scripts/tests/attention.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import torch
|
||||
import einops
|
||||
from torch.backends.cuda import SDPBackend
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
|
||||
|
||||
|
||||
def benchmark_attn():
|
||||
# Let's define a helpful benchmarking function:
|
||||
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||
)
|
||||
return t0.blocked_autorange().mean * 1e6
|
||||
|
||||
# Let's define the hyper-parameters of our input
|
||||
batch_size = 32
|
||||
max_sequence_len = 1024
|
||||
num_heads = 32
|
||||
embed_dimension = 32
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
query = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
key = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
value = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
|
||||
|
||||
# Let's explore the speed of each of the 3 implementations
|
||||
from torch.backends.cuda import SDPBackend, sdp_kernel
|
||||
|
||||
# Helpful arguments mapper
|
||||
backend_map = {
|
||||
SDPBackend.MATH: {
|
||||
"enable_math": True,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.FLASH_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": True,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.EFFICIENT_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": True,
|
||||
},
|
||||
}
|
||||
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
|
||||
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||
|
||||
print(
|
||||
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("Default detailed stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
print(
|
||||
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
with sdp_kernel(**backend_map[SDPBackend.MATH]):
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("Math implmentation stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
|
||||
try:
|
||||
print(
|
||||
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
except RuntimeError:
|
||||
print("FlashAttention is not supported. See warnings for reasons.")
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("FlashAttention stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
|
||||
try:
|
||||
print(
|
||||
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
except RuntimeError:
|
||||
print("EfficientAttention is not supported. See warnings for reasons.")
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("EfficientAttention stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
|
||||
def run_model(model, x, context):
|
||||
return model(x, context)
|
||||
|
||||
|
||||
def benchmark_transformer_blocks():
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||
)
|
||||
return t0.blocked_autorange().mean * 1e6
|
||||
|
||||
checkpoint = True
|
||||
compile = False
|
||||
|
||||
batch_size = 32
|
||||
h, w = 64, 64
|
||||
context_len = 77
|
||||
embed_dimension = 1024
|
||||
context_dim = 1024
|
||||
d_head = 64
|
||||
|
||||
transformer_depth = 4
|
||||
|
||||
n_heads = embed_dimension // d_head
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
model_native = SpatialTransformer(
|
||||
embed_dimension,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=context_dim,
|
||||
use_linear=True,
|
||||
use_checkpoint=checkpoint,
|
||||
attn_type="softmax",
|
||||
depth=transformer_depth,
|
||||
sdp_backend=SDPBackend.FLASH_ATTENTION,
|
||||
).to(device)
|
||||
model_efficient_attn = SpatialTransformer(
|
||||
embed_dimension,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=context_dim,
|
||||
use_linear=True,
|
||||
depth=transformer_depth,
|
||||
use_checkpoint=checkpoint,
|
||||
attn_type="softmax-xformers",
|
||||
).to(device)
|
||||
if not checkpoint and compile:
|
||||
print("compiling models")
|
||||
model_native = torch.compile(model_native)
|
||||
model_efficient_attn = torch.compile(model_efficient_attn)
|
||||
|
||||
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
|
||||
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
|
||||
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
|
||||
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
print(
|
||||
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
|
||||
)
|
||||
print(
|
||||
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
|
||||
)
|
||||
|
||||
print(75 * "+")
|
||||
print("NATIVE")
|
||||
print(75 * "+")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("NativeAttention stats"):
|
||||
for _ in range(25):
|
||||
model_native(x, c)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
|
||||
|
||||
print(75 * "+")
|
||||
print("Xformers")
|
||||
print(75 * "+")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("xformers stats"):
|
||||
for _ in range(25):
|
||||
model_efficient_attn(x, c)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
|
||||
|
||||
|
||||
def test01():
|
||||
# conv1x1 vs linear
|
||||
from sgm.util import count_params
|
||||
|
||||
conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
|
||||
print(count_params(conv))
|
||||
linear = torch.nn.Linear(3, 32).cuda()
|
||||
print(count_params(linear))
|
||||
|
||||
print(conv.weight.shape)
|
||||
|
||||
# use same initialization
|
||||
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
|
||||
linear.bias = torch.nn.Parameter(conv.bias)
|
||||
|
||||
print(linear.weight.shape)
|
||||
|
||||
x = torch.randn(11, 3, 64, 64).cuda()
|
||||
|
||||
xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
print(xr.shape)
|
||||
out_linear = linear(xr)
|
||||
print(out_linear.mean(), out_linear.shape)
|
||||
|
||||
out_conv = conv(x)
|
||||
print(out_conv.mean(), out_conv.shape)
|
||||
print("done with test01.\n")
|
||||
|
||||
|
||||
def test02():
|
||||
# try cosine flash attention
|
||||
import time
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
print("testing cosine flash attention...")
|
||||
DIM = 1024
|
||||
SEQLEN = 4096
|
||||
BS = 16
|
||||
|
||||
print(" softmax (vanilla) first...")
|
||||
model = BasicTransformerBlock(
|
||||
dim=DIM,
|
||||
n_heads=16,
|
||||
d_head=64,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
attn_mode="softmax",
|
||||
).cuda()
|
||||
try:
|
||||
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||
tic = time.time()
|
||||
y = model(x)
|
||||
toc = time.time()
|
||||
print(y.shape, toc - tic)
|
||||
except RuntimeError as e:
|
||||
# likely oom
|
||||
print(str(e))
|
||||
|
||||
print("\n now flash-cosine...")
|
||||
model = BasicTransformerBlock(
|
||||
dim=DIM,
|
||||
n_heads=16,
|
||||
d_head=64,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
attn_mode="flash-cosine",
|
||||
).cuda()
|
||||
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||
tic = time.time()
|
||||
y = model(x)
|
||||
toc = time.time()
|
||||
print(y.shape, toc - tic)
|
||||
print("done with test02.\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test01()
|
||||
# test02()
|
||||
# test03()
|
||||
|
||||
# benchmark_attn()
|
||||
benchmark_transformer_blocks()
|
||||
|
||||
print("done.")
|
||||
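The torch.utils.benchmark pattern used throughout this script can be reused to time any callable; a minimal standalone sketch:

import torch
import torch.utils.benchmark as benchmark

def time_in_microseconds(f, *args, **kwargs):
    # same pattern as benchmark_torch_function_in_microseconds above
    t = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t.blocked_autorange().mean * 1e6

print(time_in_microseconds(torch.matmul, torch.rand(256, 256), torch.rand(256, 256)))
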
0
scripts/util/__init__.py
Normal file
0
scripts/util/detection/__init__.py
Normal file
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
import clip
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
import clip
|
||||
|
||||
RESOURCES_ROOT = "scripts/util/detection/"
|
||||
|
||||
|
||||
13
setup.py
@@ -1,13 +0,0 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="sgm",
|
||||
version="0.0.1",
|
||||
packages=find_packages(),
|
||||
python_requires=">=3.8",
|
||||
py_modules=["sgm"],
|
||||
description="Stability Generative Models",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/Stability-AI/generative-models",
|
||||
)
|
||||
@@ -1,3 +1,4 @@
|
||||
from .data import StableDataModuleFromConfig
|
||||
from .models import AutoencodingEngine, DiffusionEngine
|
||||
from .util import instantiate_from_config
|
||||
from .util import get_configs_path, instantiate_from_config
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
from torchvision import transforms
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class CIFAR10DataDictWrapper(Dataset):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
from torchvision import transforms
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class MNISTDataDictWrapper(Dataset):
|
||||
|
||||
388
sgm/inference/api.py
Normal file
@@ -0,0 +1,388 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from omegaconf import OmegaConf
|
||||
import pathlib
|
||||
from sgm.inference.helpers import (
|
||||
do_sample,
|
||||
do_img2img,
|
||||
Img2ImgDiscretizationWrapper,
|
||||
)
|
||||
from sgm.modules.diffusionmodules.sampling import (
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
EulerAncestralSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
DPMPP2MSampler,
|
||||
LinearMultistepSampler,
|
||||
)
|
||||
from sgm.util import load_model_from_config
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
SD_2_1 = "stable-diffusion-v2-1"
|
||||
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
||||
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||
|
||||
|
||||
class Sampler(str, Enum):
|
||||
EULER_EDM = "EulerEDMSampler"
|
||||
HEUN_EDM = "HeunEDMSampler"
|
||||
EULER_ANCESTRAL = "EulerAncestralSampler"
|
||||
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
||||
DPMPP2M = "DPMPP2MSampler"
|
||||
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
||||
|
||||
|
||||
class Discretization(str, Enum):
|
||||
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
||||
EDM = "EDMDiscretization"
|
||||
|
||||
|
||||
class Guider(str, Enum):
|
||||
VANILLA = "VanillaCFG"
|
||||
IDENTITY = "IdentityGuider"
|
||||
|
||||
|
||||
class Thresholder(str, Enum):
|
||||
NONE = "None"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingParams:
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
steps: int = 50
|
||||
sampler: Sampler = Sampler.DPMPP2M
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
thresholder: Thresholder = Thresholder.NONE
|
||||
scale: float = 6.0
|
||||
aesthetic_score: float = 5.0
|
||||
negative_aesthetic_score: float = 5.0
|
||||
img2img_strength: float = 1.0
|
||||
orig_width: int = 1024
|
||||
orig_height: int = 1024
|
||||
crop_coords_top: int = 0
|
||||
crop_coords_left: int = 0
|
||||
sigma_min: float = 0.0292
|
||||
sigma_max: float = 14.6146
|
||||
rho: float = 3.0
|
||||
s_churn: float = 0.0
|
||||
s_tmin: float = 0.0
|
||||
s_tmax: float = 999.0
|
||||
s_noise: float = 1.0
|
||||
eta: float = 1.0
|
||||
order: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingSpec:
|
||||
width: int
|
||||
height: int
|
||||
channels: int
|
||||
factor: int
|
||||
is_legacy: bool
|
||||
config: str
|
||||
ckpt: str
|
||||
is_guided: bool
|
||||
|
||||
|
||||
model_specs = {
|
||||
ModelArchitecture.SD_2_1: SamplingSpec(
|
||||
height=512,
|
||||
width=512,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=True,
|
||||
config="sd_2_1.yaml",
|
||||
ckpt="v2-1_512-ema-pruned.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
||||
height=768,
|
||||
width=768,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=True,
|
||||
config="sd_2_1_768.yaml",
|
||||
ckpt="v2-1_768-ema-pruned.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=False,
|
||||
config="sd_xl_base.yaml",
|
||||
ckpt="sd_xl_base_0.9.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=True,
|
||||
config="sd_xl_refiner.yaml",
|
||||
ckpt="sd_xl_refiner_0.9.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=False,
|
||||
config="sd_xl_base.yaml",
|
||||
ckpt="sd_xl_base_1.0.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
factor=8,
|
||||
is_legacy=True,
|
||||
config="sd_xl_refiner.yaml",
|
||||
ckpt="sd_xl_refiner_1.0.safetensors",
|
||||
is_guided=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SamplingPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: ModelArchitecture,
|
||||
model_path="checkpoints",
|
||||
config_path="configs/inference",
|
||||
device="cuda",
|
||||
use_fp16=True,
|
||||
) -> None:
|
||||
if model_id not in model_specs:
|
||||
raise ValueError(f"Model {model_id} not supported")
|
||||
self.model_id = model_id
|
||||
self.specs = model_specs[self.model_id]
|
||||
self.config = str(pathlib.Path(config_path, self.specs.config))
|
||||
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
||||
self.device = device
|
||||
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||
|
||||
def _load_model(self, device="cuda", use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
model = load_model_from_config(config, self.ckpt)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {self.model_id} could not be loaded")
|
||||
model.to(device)
|
||||
if use_fp16:
|
||||
model.conditioner.half()
|
||||
model.model.half()
|
||||
return model
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
value_dict = asdict(params)
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
value_dict["target_width"] = params.width
|
||||
value_dict["target_height"] = params.height
|
||||
return do_sample(
|
||||
self.model,
|
||||
sampler,
|
||||
value_dict,
|
||||
samples,
|
||||
params.height,
|
||||
params.width,
|
||||
self.specs.channels,
|
||||
self.specs.factor,
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
image,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
if params.img2img_strength < 1.0:
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization,
|
||||
strength=params.img2img_strength,
|
||||
)
|
||||
height, width = image.shape[2], image.shape[3]
|
||||
value_dict = asdict(params)
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
value_dict["target_width"] = width
|
||||
value_dict["target_height"] = height
|
||||
return do_img2img(
|
||||
image,
|
||||
self.model,
|
||||
sampler,
|
||||
value_dict,
|
||||
samples,
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
)
|
||||
|
||||
def refiner(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
image,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
value_dict = {
|
||||
"orig_width": image.shape[3] * 8,
|
||||
"orig_height": image.shape[2] * 8,
|
||||
"target_width": image.shape[3] * 8,
|
||||
"target_height": image.shape[2] * 8,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"crop_coords_top": 0,
|
||||
"crop_coords_left": 0,
|
||||
"aesthetic_score": 6.0,
|
||||
"negative_aesthetic_score": 2.5,
|
||||
}
|
||||
|
||||
return do_img2img(
|
||||
image,
|
||||
self.model,
|
||||
sampler,
|
||||
value_dict,
|
||||
samples,
|
||||
skip_encode=True,
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
)
|
||||
|
||||
|
||||
def get_guider_config(params: SamplingParams):
|
||||
if params.guider == Guider.IDENTITY:
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif params.guider == Guider.VANILLA:
|
||||
scale = params.scale
|
||||
|
||||
thresholder = params.thresholder
|
||||
|
||||
if thresholder == Thresholder.NONE:
|
||||
dyn_thresh_config = {
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return guider_config
|
||||
|
||||
|
||||
def get_discretization_config(params: SamplingParams):
|
||||
if params.discretization == Discretization.LEGACY_DDPM:
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||
}
|
||||
elif params.discretization == Discretization.EDM:
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||
"params": {
|
||||
"sigma_min": params.sigma_min,
|
||||
"sigma_max": params.sigma_max,
|
||||
"rho": params.rho,
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unknown discretization {params.discretization}")
|
||||
return discretization_config
|
||||
|
||||
|
||||
def get_sampler_config(params: SamplingParams):
|
||||
discretization_config = get_discretization_config(params)
|
||||
guider_config = get_guider_config(params)
|
||||
sampler = None
|
||||
if params.sampler == Sampler.EULER_EDM:
|
||||
return EulerEDMSampler(
|
||||
num_steps=params.steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=params.s_churn,
|
||||
s_tmin=params.s_tmin,
|
||||
s_tmax=params.s_tmax,
|
||||
s_noise=params.s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
if params.sampler == Sampler.HEUN_EDM:
|
||||
return HeunEDMSampler(
|
||||
num_steps=params.steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=params.s_churn,
|
||||
s_tmin=params.s_tmin,
|
||||
s_tmax=params.s_tmax,
|
||||
s_noise=params.s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
if params.sampler == Sampler.EULER_ANCESTRAL:
|
||||
return EulerAncestralSampler(
|
||||
num_steps=params.steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=params.eta,
|
||||
s_noise=params.s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
||||
return DPMPP2SAncestralSampler(
|
||||
num_steps=params.steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=params.eta,
|
||||
s_noise=params.s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
if params.sampler == Sampler.DPMPP2M:
|
||||
return DPMPP2MSampler(
|
||||
num_steps=params.steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
verbose=True,
|
||||
)
|
||||
    if params.sampler == Sampler.LINEAR_MULTISTEP:
        return LinearMultistepSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=params.order,
            verbose=True,
        )

    raise ValueError(f"unknown sampler {params.sampler}!")
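A minimal usage sketch of this API (it assumes the checkpoint and config listed in model_specs are already present under checkpoints/ and configs/inference/; the prompt is illustrative):

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
params = SamplingParams(width=1024, height=1024, steps=40)
samples = pipeline.text_to_image(
    params, prompt="a photo of a red panda", negative_prompt="", samples=1
)
# `samples` is a (N, C, H, W) tensor clamped to [0, 1]; see sgm.inference.helpers
# (perform_save_locally) for writing the images to disk.
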
305
sgm/inference/helpers.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import os
|
||||
from typing import Union, List, Optional
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from imwatermark import WatermarkEncoder
|
||||
from omegaconf import ListConfig
|
||||
from torch import autocast
|
||||
|
||||
from sgm.util import append_dims
|
||||
|
||||
|
||||
class WatermarkEmbedder:
|
||||
def __init__(self, watermark):
|
||||
self.watermark = watermark
|
||||
self.num_bits = len(WATERMARK_BITS)
|
||||
self.encoder = WatermarkEncoder()
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def __call__(self, image: torch.Tensor):
|
||||
"""
|
||||
Adds a predefined watermark to the input image
|
||||
|
||||
Args:
|
||||
image: ([N,] B, C, H, W) in range [0, 1]
|
||||
|
||||
Returns:
|
||||
same as input but watermarked
|
||||
"""
|
||||
# watermarking library expects input as cv2 BGR format
|
||||
squeeze = len(image.shape) == 4
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange(
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
image = torch.from_numpy(
|
||||
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
||||
).to(image.device)
|
||||
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||
if squeeze:
|
||||
image = image[0]
|
||||
return image
|
||||
|
||||
|
||||
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
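Usage sketch: the embedder takes image tensors in [0, 1] (shape ([N,] B, C, H, W), per the docstring above) and is applied right before images are written to disk, as perform_save_locally below does:

# `samples` is a (B, C, H, W) float tensor in [0, 1]
samples = embed_watermark(samples)
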
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list({x.input_key for x in conditioner.embedders})
|
||||
|
||||
|
||||
def perform_save_locally(save_path, samples):
|
||||
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||
base_count = len(os.listdir(os.path.join(save_path)))
|
||||
samples = embed_watermark(samples)
|
||||
for sample in samples:
|
||||
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(sample.astype(np.uint8)).save(
|
||||
os.path.join(save_path, f"{base_count:09}.png")
|
||||
)
|
||||
base_count += 1
|
||||
|
||||
|
||||
class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large and then decrease
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img:", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
        print("prune index:", max(int(self.strength * len(sigmas)), 1))
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning:", sigmas)
        return sigmas
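Sketch of how the wrapper is attached to an existing sampler for partial-strength img2img (the 0.5 is illustrative; SamplingPipeline.image_to_image in api.py does the same thing):

# keep roughly the last half of the noise schedule so the input image is partially preserved
sampler.discretization = Img2ImgDiscretizationWrapper(
    sampler.discretization, strength=0.5
)
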
|
||||
|
||||
def do_sample(
|
||||
model,
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings: Optional[List] = None,
|
||||
batch2model_input: Optional[List] = None,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
device="cuda",
|
||||
):
|
||||
if force_uc_zero_embeddings is None:
|
||||
force_uc_zero_embeddings = []
|
||||
if batch2model_input is None:
|
||||
batch2model_input = []
|
||||
|
||||
with torch.no_grad():
|
||||
with autocast(device) as precision_scope:
|
||||
with model.ema_scope():
|
||||
num_samples = [num_samples]
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
num_samples,
|
||||
)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
print(key, batch[key].shape)
|
||||
elif isinstance(batch[key], list):
|
||||
print(key, [len(l) for l in batch[key]])
|
||||
else:
|
||||
print(key, batch[key])
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
)
|
||||
|
||||
for k in c:
|
||||
if not k == "crossattn":
|
||||
c[k], uc[k] = map(
|
||||
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
||||
)
|
||||
|
||||
additional_model_inputs = {}
|
||||
for k in batch2model_input:
|
||||
additional_model_inputs[k] = batch[k]
|
||||
|
||||
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||
randn = torch.randn(shape).to(device)
|
||||
|
||||
def denoiser(input, sigma, c):
|
||||
return model.denoiser(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
# Hardcoded demo setups; might undergo some changes in the future
|
||||
|
||||
batch = {}
|
||||
batch_uc = {}
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor(
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
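A sketch of the value_dict expected for an SDXL-style conditioner; the keys mirror the branches handled above and the values are illustrative:

value_dict = {
    "prompt": "a photo of a red panda",
    "negative_prompt": "",
    "orig_width": 1024, "orig_height": 1024,
    "target_width": 1024, "target_height": 1024,
    "crop_coords_top": 0, "crop_coords_left": 0,
}
batch, batch_uc = get_batch(
    get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [1]
)
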
def get_input_image_tensor(image: Image.Image, device="cuda"):
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    width, height = map(
        lambda x: x - x % 64, (w, h)
    )  # resize to integer multiple of 64
    image = image.resize((width, height))
    image_array = np.array(image.convert("RGB"))
    image_array = image_array[None].transpose(0, 3, 1, 2)
    image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
    return image_tensor.to(device)


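Typical use, as a sketch: load a PIL image, convert it to the [-1, 1] tensor this helper returns, and hand it to do_img2img below (the file name is illustrative; value_dict and sampler are set up as elsewhere in this module):

init_image = get_input_image_tensor(Image.open("input.png"))
samples = do_img2img(init_image, model, sampler, value_dict, num_samples=1)
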
def do_img2img(
|
||||
img,
|
||||
model,
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=[],
|
||||
additional_kwargs={},
|
||||
offset_noise_level: float = 0.0,
|
||||
return_latents=False,
|
||||
skip_encode=False,
|
||||
filter=None,
|
||||
device="cuda",
|
||||
):
|
||||
with torch.no_grad():
|
||||
with autocast(device) as precision_scope:
|
||||
with model.ema_scope():
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
[num_samples],
|
||||
)
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
)
|
||||
|
||||
for k in c:
|
||||
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
||||
|
||||
for k in additional_kwargs:
|
||||
c[k] = uc[k] = additional_kwargs[k]
|
||||
if skip_encode:
|
||||
z = img
|
||||
else:
|
||||
z = model.encode_first_stage(img)
|
||||
noise = torch.randn_like(z)
|
||||
sigmas = sampler.discretization(sampler.num_steps)
|
||||
sigma = sigmas[0].to(z.device)
|
||||
|
||||
if offset_noise_level > 0.0:
|
||||
noise = noise + offset_noise_level * append_dims(
|
||||
torch.randn(z.shape[0], device=z.device), z.ndim
|
||||
)
|
||||
noised_z = z + noise * append_dims(sigma, z.ndim)
|
||||
noised_z = noised_z / torch.sqrt(
|
||||
1.0 + sigmas[0] ** 2.0
|
||||
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
||||
|
||||
def denoiser(x, sigma, c):
|
||||
return model.denoiser(model.model, x, sigma, c)
|
||||
|
||||
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
@@ -258,14 +258,10 @@ class DiffusionEngine(pl.LightningModule):
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif isinstance(x, Union[List, ListConfig]):
|
||||
elif isinstance(x, (List, ListConfig)):
|
||||
if isinstance(x[0], str):
|
||||
# strings
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
||||
elif isinstance(x[0], Union[ListConfig, List]):
|
||||
# # case: videos processed
|
||||
x = [xx[0] for xx in x]
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
|
||||
@@ -631,317 +631,3 @@ class SpatialTransformer(nn.Module):
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
def benchmark_attn():
|
||||
# Lets define a helpful benchmarking function:
|
||||
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||
)
|
||||
return t0.blocked_autorange().mean * 1e6
|
||||
|
||||
# Lets define the hyper-parameters of our input
|
||||
batch_size = 32
|
||||
max_sequence_len = 1024
|
||||
num_heads = 32
|
||||
embed_dimension = 32
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
query = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
key = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
value = torch.rand(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
|
||||
|
||||
# Lets explore the speed of each of the 3 implementations
|
||||
from torch.backends.cuda import SDPBackend, sdp_kernel
|
||||
|
||||
# Helpful arguments mapper
|
||||
backend_map = {
|
||||
SDPBackend.MATH: {
|
||||
"enable_math": True,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.FLASH_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": True,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.EFFICIENT_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": True,
|
||||
},
|
||||
}
|
||||
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
|
||||
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||
|
||||
print(
|
||||
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("Default detailed stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
print(
|
||||
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
with sdp_kernel(**backend_map[SDPBackend.MATH]):
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("Math implmentation stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
|
||||
try:
|
||||
print(
|
||||
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
except RuntimeError:
|
||||
print("FlashAttention is not supported. See warnings for reasons.")
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("FlashAttention stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
|
||||
try:
|
||||
print(
|
||||
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||
)
|
||||
except RuntimeError:
|
||||
print("EfficientAttention is not supported. See warnings for reasons.")
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("EfficientAttention stats"):
|
||||
for _ in range(25):
|
||||
o = F.scaled_dot_product_attention(query, key, value)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
|
||||
|
||||
def run_model(model, x, context):
|
||||
return model(x, context)
|
||||
|
||||
|
||||
def benchmark_transformer_blocks():
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||
)
|
||||
return t0.blocked_autorange().mean * 1e6
|
||||
|
||||
checkpoint = True
|
||||
compile = False
|
||||
|
||||
batch_size = 32
|
||||
h, w = 64, 64
|
||||
context_len = 77
|
||||
embed_dimension = 1024
|
||||
context_dim = 1024
|
||||
d_head = 64
|
||||
|
||||
transformer_depth = 4
|
||||
|
||||
n_heads = embed_dimension // d_head
|
||||
|
||||
dtype = torch.float16
|
||||
|
||||
model_native = SpatialTransformer(
|
||||
embed_dimension,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=context_dim,
|
||||
use_linear=True,
|
||||
use_checkpoint=checkpoint,
|
||||
attn_type="softmax",
|
||||
depth=transformer_depth,
|
||||
sdp_backend=SDPBackend.FLASH_ATTENTION,
|
||||
).to(device)
|
||||
model_efficient_attn = SpatialTransformer(
|
||||
embed_dimension,
|
||||
n_heads,
|
||||
d_head,
|
||||
context_dim=context_dim,
|
||||
use_linear=True,
|
||||
depth=transformer_depth,
|
||||
use_checkpoint=checkpoint,
|
||||
attn_type="softmax-xformers",
|
||||
).to(device)
|
||||
if not checkpoint and compile:
|
||||
print("compiling models")
|
||||
model_native = torch.compile(model_native)
|
||||
model_efficient_attn = torch.compile(model_efficient_attn)
|
||||
|
||||
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
|
||||
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
|
||||
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
|
||||
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
print(
|
||||
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
|
||||
)
|
||||
print(
|
||||
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
|
||||
)
|
||||
|
||||
print(75 * "+")
|
||||
print("NATIVE")
|
||||
print(75 * "+")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("NativeAttention stats"):
|
||||
for _ in range(25):
|
||||
model_native(x, c)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
|
||||
|
||||
print(75 * "+")
|
||||
print("Xformers")
|
||||
print(75 * "+")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("xformers stats"):
|
||||
for _ in range(25):
|
||||
model_efficient_attn(x, c)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
|
||||
|
||||
|
||||
def test01():
|
||||
# conv1x1 vs linear
|
||||
from ..util import count_params
|
||||
|
||||
conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
|
||||
print(count_params(conv))
|
||||
linear = torch.nn.Linear(3, 32).cuda()
|
||||
print(count_params(linear))
|
||||
|
||||
print(conv.weight.shape)
|
||||
|
||||
# use same initialization
|
||||
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
|
||||
linear.bias = torch.nn.Parameter(conv.bias)
|
||||
|
||||
print(linear.weight.shape)
|
||||
|
||||
x = torch.randn(11, 3, 64, 64).cuda()
|
||||
|
||||
xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
print(xr.shape)
|
||||
out_linear = linear(xr)
|
||||
print(out_linear.mean(), out_linear.shape)
|
||||
|
||||
out_conv = conv(x)
|
||||
print(out_conv.mean(), out_conv.shape)
|
||||
print("done with test01.\n")
|
||||
|
||||
|
||||
def test02():
|
||||
# try cosine flash attention
|
||||
import time
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
print("testing cosine flash attention...")
|
||||
DIM = 1024
|
||||
SEQLEN = 4096
|
||||
BS = 16
|
||||
|
||||
print(" softmax (vanilla) first...")
|
||||
model = BasicTransformerBlock(
|
||||
dim=DIM,
|
||||
n_heads=16,
|
||||
d_head=64,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
attn_mode="softmax",
|
||||
).cuda()
|
||||
try:
|
||||
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||
tic = time.time()
|
||||
y = model(x)
|
||||
toc = time.time()
|
||||
print(y.shape, toc - tic)
|
||||
except RuntimeError as e:
|
||||
# likely oom
|
||||
print(str(e))
|
||||
|
||||
print("\n now flash-cosine...")
|
||||
model = BasicTransformerBlock(
|
||||
dim=DIM,
|
||||
n_heads=16,
|
||||
d_head=64,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
attn_mode="flash-cosine",
|
||||
).cuda()
|
||||
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||
tic = time.time()
|
||||
y = model(x)
|
||||
toc = time.time()
|
||||
print(y.shape, toc - tic)
|
||||
print("done with test02.\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test01()
|
||||
# test02()
|
||||
# test03()
|
||||
|
||||
# benchmark_attn()
|
||||
benchmark_transformer_blocks()
|
||||
|
||||
print("done.")
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import Any, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
from ..lpips.model.model import NLayerDiscriminator, weights_init
|
||||
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
|
||||
0
sgm/modules/autoencoding/lpips/__init__.py
Normal file
1
sgm/modules/autoencoding/lpips/loss/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
vgg.pth
|
||||
23
sgm/modules/autoencoding/lpips/loss/LICENSE
Normal file
@@ -0,0 +1,23 @@
|
||||
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
0
sgm/modules/autoencoding/lpips/loss/__init__.py
Normal file
147
sgm/modules/autoencoding/lpips/loss/lpips.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from ..util import get_ckpt_path
|
||||
|
||||
|
||||
class LPIPS(nn.Module):
|
||||
# Learned perceptual metric
|
||||
def __init__(self, use_dropout=True):
|
||||
super().__init__()
|
||||
self.scaling_layer = ScalingLayer()
|
||||
self.chns = [64, 128, 256, 512, 512]  # vgg16 features
|
||||
self.net = vgg16(pretrained=True, requires_grad=False)
|
||||
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
||||
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
||||
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
||||
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
||||
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
||||
self.load_from_pretrained()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_from_pretrained(self, name="vgg_lpips"):
|
||||
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
||||
self.load_state_dict(
|
||||
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
||||
)
|
||||
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, name="vgg_lpips"):
|
||||
if name != "vgg_lpips":
|
||||
raise NotImplementedError
|
||||
model = cls()
|
||||
ckpt = get_ckpt_path(name)
|
||||
model.load_state_dict(
|
||||
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
||||
)
|
||||
return model
|
||||
|
||||
def forward(self, input, target):
|
||||
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
||||
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
||||
for kk in range(len(self.chns)):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
||||
outs1[kk]
|
||||
)
|
||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||
|
||||
res = [
|
||||
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
||||
for kk in range(len(self.chns))
|
||||
]
|
||||
val = res[0]
|
||||
for l in range(1, len(self.chns)):
|
||||
val += res[l]
|
||||
return val
|
||||
|
||||
|
||||
class ScalingLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super(ScalingLayer, self).__init__()
|
||||
self.register_buffer(
|
||||
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
||||
)
|
||||
self.register_buffer(
|
||||
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
return (inp - self.shift) / self.scale
|
||||
|
||||
|
||||
class NetLinLayer(nn.Module):
|
||||
"""A single linear layer which does a 1x1 conv"""
|
||||
|
||||
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
||||
super(NetLinLayer, self).__init__()
|
||||
layers = (
|
||||
[
|
||||
nn.Dropout(),
|
||||
]
|
||||
if (use_dropout)
|
||||
else []
|
||||
)
|
||||
layers += [
|
||||
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
||||
]
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
|
||||
class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple(
            "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
        )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)
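For orientation, a minimal usage sketch of the LPIPS metric defined above. It is not part of the diff: the tensor shapes and the [-1, 1] input range are assumptions (the ScalingLayer statistics expect inputs in that range), and the variable names are illustrative.

# Hypothetical usage sketch: perceptual distance between two image batches.
lpips = LPIPS().eval()                   # loads the vgg_lpips checkpoint via get_ckpt_path
x = torch.rand(4, 3, 256, 256) * 2 - 1   # "reconstruction", assumed scaled to [-1, 1]
y = torch.rand(4, 3, 256, 256) * 2 - 1   # "target", assumed scaled to [-1, 1]
with torch.no_grad():
    dist = lpips(x, y)                   # shape (4, 1, 1, 1); lower means more similar
print(dist.flatten())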
58  sgm/modules/autoencoding/lpips/model/LICENSE  Normal file
@@ -0,0 +1,58 @@
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License

For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License

For dcgan.torch software

Copyright (c) 2015, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0  sgm/modules/autoencoding/lpips/model/__init__.py  Normal file
88  sgm/modules/autoencoding/lpips/model/model.py  Normal file
@@ -0,0 +1,88 @@
import functools

import torch.nn as nn

from ..util import ActNorm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)     -- the number of channels in input images
            ndf (int)          -- the number of filters in the last conv layer
            n_layers (int)     -- the number of conv layers in the discriminator
            use_actnorm (bool) -- use ActNorm instead of BatchNorm as the normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if (
            type(norm_layer) == functools.partial
        ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
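A minimal sketch of how this PatchGAN discriminator is typically wired up; not part of the diff, and the batch/input sizes are examples only.

# Hypothetical usage sketch: 3-layer PatchGAN on a 256x256 RGB batch.
import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)              # DCGAN-style init for Conv / BatchNorm weights
images = torch.randn(2, 3, 256, 256)
patch_logits = disc(images)           # (2, 1, 30, 30): one realness logit per patch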
128  sgm/modules/autoencoding/lpips/util.py  Normal file
@@ -0,0 +1,128 @@
import hashlib
import os

import requests
import torch
import torch.nn as nn
from tqdm import tqdm

URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}

CKPT_MAP = {"vgg_lpips": "vgg.pth"}

MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class ActNorm(nn.Module):
    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
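A minimal sketch of ActNorm's data-dependent initialization and its inverse; not part of the diff, with arbitrary example tensors.

# Hypothetical usage sketch: the first forward call initializes loc/scale, reverse() inverts it.
import torch

actnorm = ActNorm(num_features=8, logdet=False)
actnorm.train()
x = torch.randn(16, 8, 4, 4) * 3.0 + 1.5
y = actnorm(x)                                 # per-channel ~zero mean / unit std after init
x_rec = actnorm.reverse(y)
print(torch.allclose(x, x_rec, atol=1e-4))     # True up to floating-point error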
17  sgm/modules/autoencoding/lpips/vqperceptual.py  Normal file
@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real))
        + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss
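A quick numeric sanity check of the two discriminator losses above, on made-up logits; not part of the diff, values are illustrative only.

# Hypothetical sketch: hinge vs. non-saturating ("vanilla") discriminator loss.
import torch

logits_real = torch.full((2, 1, 30, 30), 0.8)     # discriminator output on real patches
logits_fake = torch.full((2, 1, 30, 30), -0.6)    # discriminator output on fake patches
print(hinge_d_loss(logits_real, logits_fake))     # 0.5 * (relu(1 - 0.8) + relu(1 + (-0.6))) = 0.3
print(vanilla_d_loss(logits_real, logits_fake))   # 0.5 * (softplus(-0.8) + softplus(-0.6)) ≈ 0.40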
@@ -1,7 +1,7 @@
|
||||
from .denoiser import Denoiser
|
||||
from .discretizer import Discretization
|
||||
from .loss import StandardDiffusionLoss
|
||||
from .model import Model, Encoder, Decoder
|
||||
from .model import Decoder, Encoder, Model
|
||||
from .openaimodel import UNetModel
|
||||
from .sampling import BaseDiffusionSampler
|
||||
from .wrappers import OpenAIWrapper
|
||||
|
||||
@@ -1,17 +1,29 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
from ...util import append_zero
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...modules.diffusionmodules.util import make_beta_schedule
|
||||
from ...util import append_zero
|
||||
|
||||
|
||||
def generate_roughly_equally_spaced_steps(
|
||||
num_substeps: int, max_step: int
|
||||
) -> np.ndarray:
|
||||
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
|
||||
|
||||
|
||||
class Discretization:
|
||||
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
|
||||
sigmas = self.get_sigmas(n, device)
|
||||
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
|
||||
sigmas = self.get_sigmas(n, device=device)
|
||||
sigmas = append_zero(sigmas) if do_append_zero else sigmas
|
||||
return sigmas if not flip else torch.flip(sigmas, (0,))
|
||||
|
||||
@abstractmethod
|
||||
def get_sigmas(self, n, device):
|
||||
pass
|
||||
|
||||
|
||||
class EDMDiscretization(Discretization):
|
||||
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
|
||||
@@ -19,7 +31,7 @@ class EDMDiscretization(Discretization):
|
||||
self.sigma_max = sigma_max
|
||||
self.rho = rho
|
||||
|
||||
def get_sigmas(self, n, device):
|
||||
def get_sigmas(self, n, device="cpu"):
|
||||
ramp = torch.linspace(0, 1, n, device=device)
|
||||
min_inv_rho = self.sigma_min ** (1 / self.rho)
|
||||
max_inv_rho = self.sigma_max ** (1 / self.rho)
|
||||
@@ -33,8 +45,8 @@ class LegacyDDPMDiscretization(Discretization):
|
||||
linear_start=0.00085,
|
||||
linear_end=0.0120,
|
||||
num_timesteps=1000,
|
||||
legacy_range=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_timesteps = num_timesteps
|
||||
betas = make_beta_schedule(
|
||||
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
||||
@@ -42,23 +54,15 @@ class LegacyDDPMDiscretization(Discretization):
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
self.legacy_range = legacy_range
|
||||
|
||||
def get_sigmas(self, n, device):
|
||||
def get_sigmas(self, n, device="cpu"):
|
||||
if n < self.num_timesteps:
|
||||
c = self.num_timesteps // n
|
||||
|
||||
if self.legacy_range:
|
||||
timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
|
||||
timesteps += 1 # Legacy LDM Hack
|
||||
else:
|
||||
timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
|
||||
timesteps -= 1
|
||||
timesteps = timesteps[1:]
|
||||
|
||||
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
|
||||
alphas_cumprod = self.alphas_cumprod[timesteps]
|
||||
else:
|
||||
elif n == self.num_timesteps:
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
|
||||
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
|
||||
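For intuition, a standalone sketch of the sigma schedule that EDMDiscretization.get_sigmas computes; the final interpolation line is an assumption based on the Karras et al. EDM formulation, since the hunk above is truncated before it.

# Hypothetical standalone sketch of the EDM noise-level schedule (not the repo's API).
import torch

def edm_sigmas(n, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the rho-th power.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(edm_sigmas(5))  # decreases monotonically from 80.0 down to 0.02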
@@ -3,9 +3,9 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import ListConfig
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
|
||||
|
||||
|
||||
class StandardDiffusionLoss(nn.Module):
|
||||
|
||||
@@ -30,5 +30,5 @@ class OpenAIWrapper(IdentityWrapper):
|
||||
timesteps=t,
|
||||
context=c.get("crossattn", None),
|
||||
y=c.get("vector", None),
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
|
||||
19  sgm/util.py
@@ -212,7 +212,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
        raise NotImplementedError

    model = instantiate_from_config(config.model)

    sd = pl_sd["state_dict"]

    m, u = model.load_state_dict(sd, strict=False)

@@ -229,3 +228,21 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):

    model.eval()
    return model


def get_configs_path() -> str:
    """
    Get the `configs` directory.
    For a working copy, this is the one in the root of the repository,
    but for an installed copy, it's in the `sgm` package (see pyproject.toml).
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    for candidate in candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
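A minimal sketch of how the new get_configs_path helper might be used; not part of the diff, and the YAML file name is illustrative only.

# Hypothetical usage sketch: resolve a config whether sgm is installed or run from a checkout.
import os
from omegaconf import OmegaConf
from sgm.util import get_configs_path

config_file = os.path.join(get_configs_path(), "inference", "sd_xl_base.yaml")  # illustrative name
config = OmegaConf.load(config_file)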
111  tests/inference/test_inference.py  Normal file
@@ -0,0 +1,111 @@
import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple

from sgm.inference.api import (
    model_specs,
    SamplingParams,
    SamplingPipeline,
    Sampler,
    ModelArchitecture,
)
import sgm.inference.helpers as helpers


@pytest.mark.inference
class TestInference:
    @fixture(scope="class", params=model_specs.keys())
    def pipeline(self, request) -> SamplingPipeline:
        pipeline = SamplingPipeline(request.param)
        yield pipeline
        del pipeline
        torch.cuda.empty_cache()

    @fixture(
        scope="class",
        params=[
            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
            [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
        ],
        ids=["SDXL_V1", "SDXL_V0_9"],
    )
    def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
        base_pipeline = SamplingPipeline(request.param[0])
        refiner_pipeline = SamplingPipeline(request.param[1])
        yield base_pipeline, refiner_pipeline
        del base_pipeline
        del refiner_pipeline
        torch.cuda.empty_cache()

    def create_init_image(self, h, w):
        image_array = numpy.random.rand(h, w, 3) * 255
        image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
        return helpers.get_input_image_tensor(image)

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
        output = pipeline.text_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )

        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
        output = pipeline.image_to_image(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )
        assert output is not None

    @pytest.mark.parametrize("sampler_enum", Sampler)
    @pytest.mark.parametrize(
        "use_init_image", [True, False], ids=["img2img", "txt2img"]
    )
    def test_sdxl_with_refiner(
        self,
        sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
        sampler_enum,
        use_init_image,
    ):
        base_pipeline, refiner_pipeline = sdxl_pipelines
        if use_init_image:
            output = base_pipeline.image_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
                image=self.create_init_image(
                    base_pipeline.specs.height, base_pipeline.specs.width
                ),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
            )
        else:
            output = base_pipeline.text_to_image(
                params=SamplingParams(sampler=sampler_enum.value, steps=10),
                prompt="A professional photograph of an astronaut riding a pig",
                negative_prompt="",
                samples=1,
                return_latents=True,
            )

        assert isinstance(output, (tuple, list))
        samples, samples_z = output
        assert samples is not None
        assert samples_z is not None
        refiner_pipeline.refiner(
            params=SamplingParams(sampler=sampler_enum.value, steps=10),
            image=samples_z,
            prompt="A professional photograph of an astronaut riding a pig",
            negative_prompt="",
            samples=1,
        )
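Outside of pytest, the same inference API these tests exercise can be driven directly. A minimal sketch, assuming the SDXL base checkpoint is available where SamplingPipeline expects it and that Sampler exposes an EULER_EDM member as the parametrized tests suggest.

# Hypothetical standalone sketch mirroring test_txt2img above (not part of the diff).
from sgm.inference.api import ModelArchitecture, Sampler, SamplingParams, SamplingPipeline

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
samples = pipeline.text_to_image(
    params=SamplingParams(sampler=Sampler.EULER_EDM.value, steps=30),  # assumed enum member
    prompt="A professional photograph of an astronaut riding a pig",
    negative_prompt="",
    samples=1,
)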