100 Commits

Author SHA1 Message Date
Stephan Auerhahn
7ef5489cea Merge branch 'main' into helpers-fixes 2023-08-17 10:21:35 -07:00
Vitaly Bondar
477d8b9a77 fix EDMDiscretization sigma_min for correct sampling noise scheduling (#114) 2023-08-17 08:48:30 -07:00
Stephan Auerhahn
e289621992 fix reference 2023-08-12 13:52:46 -07:00
Stephan Auerhahn
2fc4680bf9 Easier default params 2023-08-12 13:22:04 -07:00
Stephan Auerhahn
e32972b85b remove extra init 2023-08-12 05:42:22 -07:00
Stephan Auerhahn
65c6ec1cec run black 2023-08-12 05:40:25 -07:00
Stephan Auerhahn
5fde7e73b8 set a default scale 2023-08-12 05:35:36 -07:00
Stephan Auerhahn
fbe93fc53b PR fixes, model specific defaults 2023-08-12 05:33:16 -07:00
Stephan Auerhahn
c0655731d5 fix streamlit inputs 2023-08-12 04:25:56 -07:00
Stephan Auerhahn
f6704532a0 abstract device defaults 2023-08-12 07:27:25 +00:00
Stephan Auerhahn
98c4b7753b cleanup imports in test 2023-08-12 07:16:02 +00:00
Stephan Auerhahn
d4307bef5d Test model device manager and fix bugs 2023-08-12 07:15:36 +00:00
Stephan Auerhahn
fe4632034b fix for orig dimensions 2023-08-11 16:31:53 -07:00
Stephan Auerhahn
d6f2b78994 pass options into state2 init 2023-08-10 15:06:55 -07:00
Stephan Auerhahn
cd81956241 text updates 2023-08-10 13:31:03 -07:00
Stephan Auerhahn
5c17043434 change default 2023-08-10 13:15:23 -07:00
Stephan Auerhahn
2aebc8882d split fp16 and swapping functionality 2023-08-10 13:14:38 -07:00
Stephan Auerhahn
3816aaa639 simplify device_manager usage 2023-08-10 13:05:30 -07:00
Stephan Auerhahn
88395261d8 update helpers 2023-08-10 12:45:37 -07:00
Stephan Auerhahn
b3866d1218 move checkbox out of cached resource 2023-08-10 12:44:48 -07:00
Stephan Auerhahn
a25662e969 low vram checkbox fix, remove magic strings 2023-08-10 12:40:32 -07:00
Stephan Auerhahn
26b10f56f3 fix missing index 2023-08-10 12:24:12 -07:00
Stephan Auerhahn
3e7ada70c5 fix autocast 2023-08-10 05:42:31 -07:00
Stephan Auerhahn
de7a627978 more fixes and cleanup 2023-08-10 05:11:34 -07:00
Stephan Auerhahn
9b18e6fa19 update api module 2023-08-10 05:07:22 -07:00
Stephan Auerhahn
47805f233c finish device manager refactor 2023-08-10 04:55:43 -07:00
Stephan Auerhahn
e190ecc60b path helper & model swapping rewrite 2023-08-10 04:35:59 -07:00
Stephan Auerhahn
fc498bfaef remove duplicate imports 2023-08-10 03:20:56 -07:00
Stephan Auerhahn
8011d54ca1 some PR fixes 2023-08-10 03:19:37 -07:00
Stephan Auerhahn
b51c36b0df extract path resolution method, fix/improve device swapping support 2023-08-09 19:31:59 -07:00
Stephan Auerhahn
d245e2002f more types 2023-08-09 13:46:06 -07:00
Stephan Auerhahn
725bea9f75 pull in import fix 2023-08-09 13:29:16 -07:00
Stephan Auerhahn
a009aa8a9f adding some typing 2023-08-09 13:27:30 -07:00
Stephan Auerhahn
f86ffac274 context manager 2023-08-09 12:38:44 -07:00
Stephan Auerhahn
a726ce3eb7 replace usage of get 2023-08-09 12:30:43 -07:00
Stephan Auerhahn
c4b7baf896 Streamlit refactor (#105)
* initial streamlit refactoring pass

* cleanup and fixes

* fix refiner strength

* Modify params correctly

* fix exception
2023-08-06 19:58:52 -07:00
Stephan Auerhahn
7e7fee3f0f system env var 2023-08-06 19:22:59 -07:00
Stephan Auerhahn
49fe53c165 use env var for sgm checkpoints path 2023-08-06 19:21:17 -07:00
Stephan Auerhahn
6c18c8443a rename ModelOnDevice to SwapToDevice 2023-08-06 23:46:20 +00:00
Stephan Auerhahn
ced97f0e84 update defaults 2023-08-06 23:24:14 +00:00
Stephan Auerhahn
76ca428422 fix path resolution bug 2023-08-06 21:39:18 +00:00
Stephan Auerhahn
8f8757b4ff version bump for changes to inference helpers 2023-08-06 21:09:09 +00:00
Stephan Auerhahn
f2fba1dfa2 fix noisy latent handling 2023-08-06 21:08:19 +00:00
Stephan Auerhahn
451c76ada1 format 2023-08-06 12:26:16 +00:00
Stephan Auerhahn
0c2c5c66a2 fix device check 2023-08-06 12:26:01 +00:00
Stephan Auerhahn
ea5f232d5d move conditioner to device 2023-08-06 11:42:39 +00:00
Stephan Auerhahn
f06c67c206 formatting, remove reference 2023-08-06 11:30:40 +00:00
Stephan Auerhahn
b216934b7e align with streamlit helpers and re-de-deuplicate 2023-08-06 11:20:22 +00:00
Stephan Auerhahn
77d0e27747 format 2023-08-03 17:57:55 -07:00
Stephan Auerhahn
4aea6fa2a4 Fix checkpoint loading too 2023-08-03 17:56:24 -07:00
Stephan Auerhahn
84d3a7f6f5 fix fallback logic for config path 2023-08-03 17:50:10 -07:00
Stephan Auerhahn
19fa4da3de run black again 2023-08-04 00:16:29 +00:00
Stephan Auerhahn
4e2236f67d Fix path logic for development installs 2023-08-04 00:15:22 +00:00
Stephan Auerhahn
baf79d2d79 black 2023-08-04 00:00:51 +00:00
Stephan Auerhahn
44943df4f2 Allow loading custom models and improve path logic 2023-08-03 23:59:42 +00:00
Stephan Auerhahn
73287ec3a3 Extract method for img2img wrapper 2023-08-03 23:42:11 +00:00
Stephan Auerhahn
853adb4022 Add defaults to refiner function 2023-08-03 12:50:23 -07:00
Stephan Auerhahn
45feb6cb9c Use wrapper correctly in refiner helper 2023-08-02 23:14:30 +00:00
Stephan Auerhahn
45c443b316 Fix license-files setting for project (#71) 2023-07-26 20:14:23 +02:00
Jonas Müller
dea60596fc Model hashes (#70)
* Added model hashes

* Fix link
2023-07-26 20:03:48 +02:00
Stephan Auerhahn
299abbcd90 Use final v1 filename (#67) 2023-07-26 19:53:19 +02:00
Jonas Müller
e5d714d304 Improved sampling (#69)
* New research features

* Add new model specs
---------

Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>

* remove sd1.5 and change default refiner to 1.0

* remove asking second time for output

* adapt model names

* adjusted strength

* Correctly pass prompt

---------

Co-authored-by: Dominik Lorenz <53151171+qp-qp@users.noreply.github.com>
2023-07-26 19:49:23 +02:00
Robin Rombach
f2fa96b7e5 README updates for SDXL 1.0 release (#68)
* remove sdxl report from github repo, point to arxiv instead

* update licenses and add teaser img

* update readme for SDXL 1.0 release
2023-07-26 19:24:52 +02:00
Aarni Koskela
c60c091f4d Move CODEOWNERS so it has an effect (#66)
* Move CODEOWNERS so it has an effect

* CI: use vars.SGM_CHECKPOINTS_PATH
2023-07-26 08:52:59 -07:00
Stephan Auerhahn
931d7a389a Add inference helpers & tests (#57)
* Add inference helpers & tests

* Support testing with hatch

* fixes to hatch script

* add inference test action

* change workflow trigger

* widen trigger to test

* revert changes to workflow triggers

* Install local python in action

* Trigger on push again

* fix python version

* add CODEOWNERS and change triggers

* Report tests results

* update action versions

* format

* Fix typo and add refiner helper

* use a shared path loaded from a secret for checkpoints source

* typo fix

* Use device from input and remove duplicated code

* PR feedback

* fix call to load_model_from_config

* Move model to gpu

* Refactor helpers

* cleanup

* test refiner, prep for 1.0, align with metadata

* fix paths on second load

* deduplicate streamlit code

* filenames

* fixes

* add pydantic to requirements

* fix usage of `msg` in demo script

* remove double text

* run black

* fix streamlit sampling when returning latents

* extract function for streamlit output

* another fix for streamlit outputs

* fix img2img in streamlit

* Make fp16 optional and fix device param

* PR feedback

* fix dict cast for dataclass

* run black, update ci script

* cache pip dependencies on hosted runners, remove extra runs

* install package in ci env

* fix cache path

* PR cleanup

* one more cleanup

* don't cache, it filled up
2023-07-26 04:37:24 -07:00
Benjamin Aubin
e596332148 Pre release changes for production (#59)
* clean requirements

* rm taming deps

* isort, black

* mv lipips, license

* clean vq, fix path

* fix loss path, gitignore

* tested requirements pt13

* fix numpy req for python3.8, add tests

* fix name

* fix dep scipy 3.8 pt2

* add black test formatter
2023-07-26 12:09:28 +02:00
Jonas Müller
4a3f0f546e Revert "Replace most print()s with logging calls (#42)" (#65)
This reverts commit 6f6d3f8716.
2023-07-26 10:30:21 +02:00
Tim Dockhorn
7934245835 Revert "Minimize re-exports from __init__ files (#44)" (#63)
This reverts commit 57862fb4c7.
2023-07-26 10:26:28 +02:00
Tim Dockhorn
1da250906d Revert "Dead code removal (#48)" (#62)
This reverts commit b5b5680150.
2023-07-26 10:26:00 +02:00
Tim Dockhorn
a4ceca6d03 Revert "fall back to vanilla if xformers is not available (#51)" (#61)
This reverts commit ef520df1db.
2023-07-26 10:25:17 +02:00
ablattmann
68f3f89bd3 Fix crashing line in logging in sgm/models/diffusion.py (#64) 2023-07-26 10:13:42 +02:00
Aarni Koskela
57862fb4c7 Minimize re-exports from __init__ files (#44)
This allows importing parts of the package without having to
import practically everything (since importing a package will
import its parents' __init__s, etc).
2023-07-25 16:24:09 +02:00
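For illustration, a hedged sketch of the import pattern this change enables; the `sgm.util` import appears elsewhere in this diff, while the before/after `__init__.py` contents are assumptions:

```python
# Assumed "before": a heavy sgm/__init__.py re-exports submodules, so importing
# anything from the package pulls in the full model stack (torch, lightning, ...):
#
#   from .models.diffusion import DiffusionEngine
#   from .util import instantiate_from_config
#
# With re-exports minimized, callers import just the submodule they need and the
# parent __init__ no longer drags in everything else:
from sgm.util import instantiate_from_config  # only sgm.util is imported
```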
Aarni Koskela
ef520df1db fall back to vanilla if xformers is not available (#51) 2023-07-25 16:21:51 +02:00
Aarni Koskela
2897fdc99a Move attention testing/benchmarking code out of package (#47) 2023-07-25 15:40:22 +02:00
Aarni Koskela
b5b5680150 Dead code removal (#48)
* Remove old commented-out attention code

* Mark two functions as likely unused

* Use exists() and default() from sgm.util
2023-07-25 15:24:24 +02:00
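The `exists()` and `default()` helpers mentioned in the last bullet follow a convention common to this codebase family; a minimal sketch under that assumption (not copied from `sgm.util`):

```python
from inspect import isfunction
from typing import Any

def exists(x: Any) -> bool:
    # True when a value was actually provided.
    return x is not None

def default(val: Any, d: Any) -> Any:
    # Use `val` if provided, otherwise fall back to `d`,
    # calling it first when it is a function (lazy default).
    if exists(val):
        return val
    return d() if isfunction(d) else d
```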
Aarni Koskela
6f6d3f8716 Replace most print()s with logging calls (#42) 2023-07-25 15:21:30 +02:00
Luca Antiga
6ecd0a900a Fix link (#24) 2023-07-25 09:58:18 +02:00
Jonas Müller
e25e4c0df1 Merge pull request #43 from akx/fix-safetensors-load
Fix loading safetensors with load_model_from_config
2023-07-21 18:00:53 +02:00
Aarni Koskela
e5dc9669ed Set up Python packaging (#17)
* Sort .gitignore; add dist and *.py[cod]

* Use pyproject.toml + Hatch instead of setup.py

Sibling of https://github.com/Stability-AI/stablediffusion/pull/269

* Add packaging documentation
2023-07-18 13:06:05 +02:00
Aarni Koskela
48904a692d Fix loading safetensors with load_model_from_config
Previously, the `sd` from the safetensors if branch was not used at all, and `pl_sd` would have not been assigned.
2023-07-17 09:56:35 +03:00
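A hedged sketch of the failure mode described above and the shape of the fix, reusing the `sd`/`pl_sd` names from the commit message; the helper below is illustrative, not the repo's actual `load_model_from_config`:

```python
import torch
from safetensors.torch import load_file as load_safetensors

def load_state_dict(ckpt: str) -> dict:
    # Illustrative loader: both branches must feed the same variable.
    if ckpt.endswith("safetensors"):
        # Before the fix, this branch assigned `sd` while later code read `pl_sd`,
        # which was never set on this path.
        sd = load_safetensors(ckpt)
    else:
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Lightning-style checkpoints nest weights under "state_dict" (assumption).
        sd = pl_sd.get("state_dict", pl_sd)
    return sd
```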
Tim
5c10deee76 Merge branch 'main' of https://github.com/Stability-AI/generative-models into main 2023-07-09 10:40:26 -07:00
Tim
89f5413e6d Getting rid of unnecessary error message 2023-07-09 10:40:18 -07:00
Tim
ba3e7fed5a Fixing additional GPU memory on device 0 due to discretization 2023-07-09 10:40:09 -07:00
Tim Dockhorn
ea89ce793d Merge pull request #28 from jenuk/fix-samples_z
Fix `samples_z` undefined
2023-07-09 10:35:04 -07:00
jenuk
7b1978e055 Only do refiner step if samples are actually available 2023-07-07 07:48:21 +00:00
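A self-contained, hedged sketch of the guard this adds; it mirrors the `samples_z is not None` check visible in the demo-script diff further down this page, with illustrative function names:

```python
from typing import Tuple, Union
import torch

def refine(latents: torch.Tensor) -> torch.Tensor:
    # Stand-in for the refiner (img2img) stage run on the base model's latents.
    return latents

def maybe_refine(
    out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    add_pipeline: bool,
) -> torch.Tensor:
    # The base stage returns either samples, or (samples, latents).
    if isinstance(out, tuple):
        samples, samples_z = out
    else:
        samples, samples_z = out, None
    # The fix: run the refiner only when latents are actually available.
    if add_pipeline and samples_z is not None:
        samples = refine(samples_z)
    return samples
```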
Jonas Müller
9d5ace911e Merge pull request #25 from pharmapsychotic/bugfix/watermark
Fix channel ordering RGB to cv2 BGR
2023-07-06 16:12:40 +02:00
pharmapsychotic
95b9acc5c6 Reformat with black 2023-07-06 09:03:23 -05:00
pharmapsychotic
5df4d9893c Watermark encoder expects images in BGR channel order (matching cv2 imread). This fix reduces the watermark artifacts. 2023-07-05 12:05:14 -05:00
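A minimal sketch of the channel-order handling this fix describes; the message, bits, and encoder setup mirror the `WatermarkEmbedder` code visible in the `streamlit_helpers` diff below, while the RGB wrapper itself is a hypothetical illustration:

```python
import numpy as np
from imwatermark import WatermarkEncoder

# 48-bit message and encoder setup as in the demo helpers shown later in this diff.
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
encoder = WatermarkEncoder()
encoder.set_watermark("bits", WATERMARK_BITS)

def embed_watermark_rgb(image_rgb: np.ndarray) -> np.ndarray:
    """Embed the watermark in an (h, w, c) uint8 RGB image in [0, 255]."""
    bgr = image_rgb[:, :, ::-1].copy()   # RGB -> BGR, the order cv2/imwatermark expect
    bgr = encoder.encode(bgr, "dwtDct")  # same "dwtDct" method used in the helpers
    return bgr[:, :, ::-1]               # flip back to RGB for display/saving
```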
Robin Rombach
ae18ba3e87 add sdxl report 2023-07-04 13:21:13 +02:00
Tim
061d11d55d Fixing validation step for PL 1 2023-06-30 13:40:24 -07:00
Tim
2796c81a5f Making spacing function slimmer 2023-06-30 11:17:57 -07:00
Tim
e9869d7822 Changed LegacyDDPMDiscretization for sampling 2023-06-30 11:17:46 -07:00
Tim Dockhorn
613af104c6 Merge pull request #14 from patrickvonplaten/patch-1
Update sampling.py
2023-06-29 16:17:03 -07:00
Patrick von Platen
376cec3b0f Update sampling.py
Correct typo
2023-06-26 14:37:10 +02:00
Bryce Drennan
76e549dd94 Add missing init files (#11) 2023-06-26 09:41:25 +02:00
Ikko Eltociear Ashimine
5f0a2fcf48 Update README.md (#9) 2023-06-26 09:40:12 +02:00
Balanagireddy M
d8a6a97fb0 Minor spellings (#12) 2023-06-26 09:38:55 +02:00
Tim
a1af4ac4f1 Adapting txt logging for python 3.8 2023-06-25 16:42:45 -07:00
Robin Rombach
58ddbee3ee Update README.md 2023-06-22 10:48:49 -07:00
Robin Rombach
bec98beff8 Update README.md 2023-06-22 10:48:18 -07:00
53 changed files with 2810 additions and 1197 deletions

15
.github/workflows/black.yml vendored Normal file

@@ -0,0 +1,15 @@
name: Run black
on: [pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install venv
run: |
sudo apt-get -y install python3.10-venv
- uses: psf/black@stable
with:
options: "--check --verbose -l88"
src: "./sgm ./scripts ./main.py"

27
.github/workflows/test-build.yaml vendored Normal file

@@ -0,0 +1,27 @@
name: Build package
on:
push:
branches: [ main ]
pull_request:
jobs:
build:
name: Build
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]
requirements-file: ["pt2", "pt13"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/${{ matrix.requirements-file }}.txt
pip install .

34
.github/workflows/test-inference.yml vendored Normal file

@@ -0,0 +1,34 @@
name: Test inference
on:
pull_request:
push:
branches:
- main
jobs:
test:
name: "Test inference"
# This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
if: github.repository == 'stability-ai/generative-models'
runs-on: [self-hosted, slurm, g40]
steps:
- uses: actions/checkout@v3
- name: "Symlink checkpoints"
run: ln -s $SGM_CHECKPOINTS checkpoints
- name: "Setup python"
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: "Install Hatch"
run: pip install hatch
- name: "Run inference tests"
run: hatch run ci:test-inference --junit-xml test-results.xml
- name: Surface failing tests
if: always()
uses: pmeier/pytest-results-action@main
with:
path: test-results.xml
summary: true
display-options: fEX
fail-on-empty: true

17
.gitignore vendored

@@ -1,7 +1,14 @@
.pt2
.pt2_2
.pt13
# extensions
*.egg-info
build
*.py[cod]
# envs
.pt13
.pt2
# directories
/checkpoints
/dist
/outputs
/checkpoints
/build
/src

1
CODEOWNERS Normal file

@@ -0,0 +1 @@
.github @Stability-AI/infrastructure

21
LICENSE-CODE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Stability AI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

120
README.md

@@ -4,14 +4,29 @@
## News
**July 26, 2023**
- We are releasing two new open models with a permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file hashes):
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version over `SDXL-base-0.9`.
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version over `SDXL-refiner-0.9`.
![sample2](assets/001_with_eval.png)
**July 4, 2023**
- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
**June 22, 2023**
- We are releasing two new diffusion models:
- `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
- `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
- We are releasing two new diffusion models for research purposes:
- `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
- `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
**We plan to do a full release soon (July).**
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.
**We plan to do a full release soon (July).**
## The codebase
@@ -21,7 +36,7 @@ Modularity is king. This repo implements a config-driven approach where we build
### Changelog from the old `ldm` codebase
For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
- We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
@@ -43,47 +58,98 @@ cd generative-models
#### 2. Setting up the virtualenv
This is assuming you have navigated to the `generative-models` root after cloning it.
This is assuming you have navigated to the `generative-models` root after cloning it.
**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
**PyTorch 1.13**
**PyTorch 1.13**
```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```
**PyTorch 2.0**
**PyTorch 2.0**
```shell
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
pip3 install -r requirements/pt2.txt
```
## Inference:
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
#### 3. Install `sgm`
```shell
pip3 install .
```
#### 4. Install `sdata` for training
```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```
## Packaging
This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).
To build a distributable wheel, install `hatch` and run `hatch build`
(specifying `-t wheel` will skip building a sdist, which is not necessary).
```
pip install hatch
hatch build -t wheel
```
You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
Note that the package does **not** currently specify dependencies; you will need to install the required packages,
depending on your use case and PyTorch version, manually.
## Inference
We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`.
We provide file hashes for the complete file as well as for only the saved tensors in the file (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
The following models are currently supported:
- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
```
File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
```
- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
```
File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
```
- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
- [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
- [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
**Weights for SDXL**:
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your HuggingFace Account with your organization email to request access.
After obtaining the weights, place them into `checkpoints/`.
**SDXL-1.0:**
The weights of SDXL-1.0 are available (subject to a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
**SDXL-0.9:**
The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
If you would like to access these models for your research, please apply using one of the following links:
[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
This means that you can apply for any of the two links - and if you are granted - you can access both.
Please log in to your Hugging Face Account with your organization email to request access.
After obtaining the weights, place them into `checkpoints/`.
Next, start the demo using
```
@@ -137,7 +203,7 @@ run
python main.py --base configs/example_training/toy/mnist_cond.yaml
```
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depdending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
@@ -174,7 +240,7 @@ guidance.
### Dataset Handling
For large scale training we recommend using the datapipelines from our [datapipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
For large scale training we recommend using the data pipelines from our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
data keys/values,
e.g.,

BIN
assets/001_with_eval.png Normal file

Binary file not shown.


14
main.py

@@ -12,22 +12,18 @@ import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only
from sgm.util import (
exists,
instantiate_from_config,
isheatmap,
)
from sgm.util import exists, instantiate_from_config, isheatmap
MULTINODE_HACKS = True
@@ -469,9 +465,8 @@ class ImageLogger(Callback):
self.log_img(pl_module, batch, batch_idx, split="train")
@rank_zero_only
# def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, **kwargs
self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
@@ -911,11 +906,12 @@ if __name__ == "__main__":
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import datetime
import os
import socket
import requests
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")

model_licenses/LICENSE-SDXL1.0 Normal file

@@ -0,0 +1,175 @@
Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023
Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and
have the potential to transform the way artists, among other individuals, conceive and
benefit from AI or ML technologies as a tool for content creation. Notwithstanding the
current and potential benefits that these artifacts can bring to society at large, there
are also concerns about potential misuses of them, either due to their technical
limitations or ethical considerations. In short, this license strives for both the open
and responsible downstream use of the accompanying model. When it comes to the open
character, we took inspiration from open source permissive licenses regarding the grant
of IP rights. Referring to the downstream responsible use, we added use-based
restrictions not permitting the use of the model in very specific scenarios, in order
for the licensor to be able to enforce the license in case potential misuses of the
Model may occur. At the same time, we strive to promote open and responsible research on
generative models for art and content generation. Even though downstream derivative
versions of the model could be released under different licensing terms, the latter will
always have to include - at minimum - the same use-based restrictions as the ones in the
original license (this license). We believe in the intersection between open and
responsible AI development; thus, this agreement aims to strike a balance between both
in order to enable responsible open-science in the field of AI. This CreativeML Open
RAIL++-M License governs the use of the model (and its derivatives) and is informed by
the model card associated with the model. NOW THEREFORE, You and Licensor agree as
follows: Definitions "License" means the terms and conditions for use, reproduction, and
Distribution as defined in this document. "Data" means a collection of information
and/or content extracted from the dataset used with the Model, including to train,
pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content
resulting therefrom. "Model" means any accompanying machine-learning based assemblies
(including checkpoints), consisting of learnt weights, parameters (including optimizer
states), corresponding to the model architecture as embodied in the Complementary
Material, that have been trained or tuned, in whole or in part on the Data, using the
Complementary Material. "Derivatives of the Model" means all modifications to the Model,
works based on the Model, or any other model which is created or initialized by transfer
of patterns of the weights, parameters, activations or output of the Model, to the other
model, in order to cause the other model to perform similarly to the Model, including -
but not limited to - distillation methods entailing the use of intermediate data
representations or methods based on the generation of synthetic data by the Model for
training the other model. "Complementary Material" means the accompanying source code
and scripts used to define, run, load, benchmark or evaluate the Model, and used to
prepare data for training or evaluation, if any. This includes any accompanying
documentation, tutorials, examples, etc, if any. "Distribution" means any transmission,
reproduction, publication or other sharing of the Model or Derivatives of the Model to a
third party, including providing the Model as a hosted service made available by
electronic or other remote means - e.g. API-based or web access. "Licensor" means the
copyright owner or entity authorized by the copyright owner that is granting the
License, including the persons or entities that may have rights in the Model and/or
distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising
permissions granted by this License and/or making use of the Model for whichever purpose
and in any field of use, including usage of the Model in an end-use application - e.g.
chatbot, translator, image generator. "Third Parties" means individuals or legal
entities that are not under common control with Licensor or You. "Contribution" means
any work of authorship, including the original version of the Model and any
modifications or additions to that Model or Derivatives of the Model thereof, that is
intentionally submitted to Licensor for inclusion in the Model by the copyright owner or
by an individual or Legal Entity authorized to submit on behalf of the copyright owner.
For the purposes of this definition, "submitted" means any form of electronic, verbal,
or written communication sent to the Licensor or its representatives, including but not
limited to communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for the
purpose of discussing and improving the Model, but excluding communication that is
conspicuously marked or otherwise designated in writing by the copyright owner as "Not a
Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently incorporated
within the Model.
Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the
Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of
the Model are subject to additional terms as described in
Section III. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly
display, publicly perform, sublicense, and distribute the Complementary Material, the
Model, and Derivatives of the Model. Grant of Patent License. Subject to the terms and
conditions of this License and where and as applicable, each Contributor hereby grants
to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this paragraph) patent license to make, have made, use, offer to
sell, sell, import, and otherwise transfer the Model and the Complementary Material,
where such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination of their
Contribution(s) with the Model to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a cross-claim or counterclaim
in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
incorporated within the Model and/or Complementary Material constitutes direct or
contributory patent infringement, then any patent licenses granted to You under this
License for the Model and/or Work shall terminate as of the date such litigation is
asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
Distribution and Redistribution. You may host for Third Party remote access purposes
(e.g. software-as-a-service), reproduce and distribute copies of the Model or
Derivatives of the Model thereof in any medium, with or without modifications, provided
that You meet the following conditions: Use-based restrictions as referenced in
paragraph 5 MUST be included as an enforceable provision by You in any type of legal
agreement (e.g. a license) governing the use and/or distribution of the Model or
Derivatives of the Model, and You shall give notice to subsequent users You Distribute
to, that the Model or Derivatives of the Model are subject to paragraph 5. This
provision does not apply to the use of Complementary Material. You must give any Third
Party recipients of the Model or Derivatives of the Model a copy of this License; You
must cause any modified files to carry prominent notices stating that You changed the
files; You must retain all copyright, patent, trademark, and attribution notices
excluding those notices that do not pertain to any part of the Model, Derivatives of the
Model. You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions - respecting paragraph 4.a. - for
use, reproduction, or Distribution of Your modifications, or for any such Derivatives of
the Model as a whole, provided Your use, reproduction, and Distribution of the Model
otherwise complies with the conditions stated in this License. Use-based restrictions.
The restrictions set forth in Attachment A are considered Use-based restrictions.
Therefore You cannot use the Model and the Derivatives of the Model for the specified
restricted uses. You may use the Model subject to this License, including only for
lawful purposes and in accordance with the License. Use may include creating any content
with, finetuning, updating, running, training, evaluating and/or reparametrizing the
Model. You shall require all of Your users who use the Model or a Derivative of the
Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate.
Except as set forth herein, Licensor claims no rights in the Output You generate using
the Model. You are accountable for the Output you generate and its subsequent uses. No
use of the output can contravene any provision as stated in the License.
Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent
permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage
of the Model in violation of this License. Trademarks and related. Nothing in this
License permits You to make use of Licensors trademarks, trade names, logos or to
otherwise suggest endorsement or misrepresent the relationship between the parties; and
any rights not expressly granted herein are reserved by the Licensors. Disclaimer of
Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
the Model and the Complementary Material (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied, including, without limitation, any warranties or conditions of
TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or redistributing the
Model, Derivatives of the Model, and the Complementary Material and assume any risks
associated with Your exercise of permissions under this License. Limitation of
Liability. In no event and under no legal theory, whether in tort (including
negligence), contract, or otherwise, unless required by applicable law (such as
deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental, or
consequential damages of any character arising as a result of this License or out of the
use or inability to use the Model and the Complementary Material (including but not
limited to damages for loss of goodwill, work stoppage, computer failure or malfunction,
or any and all other commercial damages or losses), even if such Contributor has been
advised of the possibility of such damages. Accepting Warranty or Additional Liability.
While redistributing the Model, Derivatives of the Model and the Complementary Material
thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty,
indemnity, or other liability obligations and/or rights consistent with this License.
However, in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability incurred by, or
claims asserted against, such Contributor by reason of your accepting any such warranty
or additional liability. If any provision of this License is held to be invalid, illegal
or unenforceable, the remaining provisions shall be unaffected thereby and remain valid
as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A Use Restrictions
You agree not to use the Model or Derivatives of the Model:
In any way that violates any applicable national, federal, state, local or
international law or regulation; For the purpose of exploiting, harming or attempting to
exploit or harm minors in any way; To generate or disseminate verifiably false
information and/or content with the purpose of harming others; To generate or
disseminate personal identifiable information that can be used to harm an individual; To
defame, disparage or otherwise harass others; For fully automated decision making that
adversely impacts an individuals legal rights or otherwise creates or modifies a
binding, enforceable obligation; For any use intended to or which has the effect of
discriminating against or harming individuals or groups based on online or offline
social behavior or known or predicted personal or personality characteristics; To
exploit any of the vulnerabilities of a specific group of persons based on their age,
social, physical or mental characteristics, in order to materially distort the behavior
of a person pertaining to that group in a manner that causes or is likely to cause that
person or another person physical or psychological harm; For any use intended to or
which has the effect of discriminating against individuals or groups based on legally
protected characteristics or categories; To provide medical advice and medical results
interpretation; To generate or disseminate information for the purpose to be used for
administration of justice, law enforcement, immigration or asylum processes, such as
predicting an individual will commit fraud/crime commitment (e.g. by text profiling,
drawing causal relationships between assertions made in documents, indiscriminate and
arbitrarily-targeted use).

48
pyproject.toml Normal file

@@ -0,0 +1,48 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "sgm"
dynamic = ["version"]
description = "Stability Generative Models"
readme = "README.md"
license-files = { paths = ["LICENSE-CODE"] }
requires-python = ">=3.8"
[project.urls]
Homepage = "https://github.com/Stability-AI/generative-models"
[tool.hatch.version]
path = "sgm/__init__.py"
[tool.hatch.build]
# This needs to be explicitly set so the configuration files
# grafted into the `sgm` directory get included in the wheel's
# RECORD file.
include = [
"sgm",
]
# The force-include configurations below make Hatch copy
# the configs/ directory (containing the various YAML files required
# to generatively model) into the source distribution and the wheel.
[tool.hatch.build.targets.sdist.force-include]
"./configs" = "sgm/configs"
[tool.hatch.build.targets.wheel.force-include]
"./configs" = "sgm/configs"
[tool.hatch.envs.ci]
skip-install = false
dependencies = [
"pytest"
]
[tool.hatch.envs.ci.scripts]
test-inference = [
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
"pip install -r requirements/pt2.txt",
"pytest -v tests/inference {args}",
]

3
pytest.ini Normal file

@@ -0,0 +1,3 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')

40
requirements/pt13.txt Normal file

@@ -0,0 +1,40 @@
black==23.7.0
chardet>=5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
numpy>=1.24.4
omegaconf>=2.3.0
onnx<=1.12.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.8.5
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=1.25.0
tensorboardx==2.5.1
timm>=0.9.2
tokenizers==0.12.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
torchaudio==0.13.1
torchdata==0.5.1
torchmetrics>=1.0.1
torchvision==0.14.1+cu117
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0.post1
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.16

39
requirements/pt2.txt Normal file

@@ -0,0 +1,39 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20

requirements_pt13.txt

@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
xformers==0.0.16
torchaudio==0.13.1
torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
matplotlib
natsort
tensorboardx==2.5.1
open-clip-torch
chardet
scipy
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
streamlit>=0.73.1
timm
tokenizers==0.12.1
torchdata==0.5.1
transformers==4.19.1
onnx<=1.12.0
triton
wandb
invisible-watermark
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .

requirements_pt2.txt

@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
torch
matplotlib
torchaudio>=2.0.2
torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fire
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.0.0
torchdata==0.6.1
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .

0
scripts/__init__.py Normal file

0
scripts/demo/__init__.py Normal file


@@ -83,7 +83,7 @@ class GetWatermarkMatch:
def __call__(self, x: np.ndarray) -> np.ndarray:
"""
Detects the number of matching bits the predefined watermark with one
or multiple images. Images should be in cv2 format, e.g. h x w x c.
or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
Args:
x: ([B], h w, c) in range [0, 255]
@@ -94,7 +94,6 @@ class GetWatermarkMatch:
squeeze = len(x.shape) == 3
if squeeze:
x = x[None, ...]
x = np.flip(x, axis=-1)
bs = x.shape[0]
detected = np.empty((bs, self.num_bits), dtype=bool)

scripts/demo/sampling.py

@@ -1,6 +1,30 @@
import os
import numpy as np
import streamlit as st
import torch
from einops import repeat
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
SamplingSpec,
SamplingParams,
ModelArchitecture,
SamplingPipeline,
model_specs,
)
from sgm.inference.helpers import (
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
from scripts.demo.streamlit_helpers import (
get_interactive_image,
init_embedder_options,
init_sampling,
init_save_locally,
init_st,
show_samples,
)
SAVE_PATH = "outputs/demo/txt2img/"
@@ -33,48 +57,6 @@ SD_XL_BASE_RATIOS = {
"3.0": (1728, 576),
}
VERSION2SPECS = {
"SD-XL base": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
"is_guided": True,
},
"sd-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
"is_guided": True,
},
"sd-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-Refiner": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
"is_guided": True,
},
}
def load_img(display=True, key=None, device="cuda"):
image = get_interactive_image(key=key)
@@ -95,164 +77,182 @@ def load_img(display=True, key=None, device="cuda"):
def run_txt2img(
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
state,
model_id: ModelArchitecture,
prompt: str,
negative_prompt: str,
return_latents=False,
stage2strength=None,
):
if version == "SD-XL base":
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
W, H = SD_XL_BASE_RATIOS[ratio]
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
if model_id in sdxl_base_model_list:
width, height = st.selectbox(
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
)
else:
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
height = int(
st.number_input("H", value=params.height, min_value=64, max_value=2048)
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
width = int(
st.number_input("W", value=params.width, min_value=64, max_value=2048)
)
C = version_dict["C"]
F = version_dict["f"]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
)
num_rows, num_cols, sampler = init_sampling(
use_identity_guider=not version_dict["is_guided"]
)
params, num_rows, num_cols = init_sampling(params=params)
num_samples = num_rows * num_cols
params.height = height
params.width = width
if st.button("Sample"):
st.write(f"**Model I:** {version}")
out = do_sample(
state["model"],
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
outputs = st.empty()
st.text("Sampling")
out = model.text_to_image(
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
return_latents=return_latents,
filter=filter,
noise_strength=stage2strength,
filter=state["filter"],
)
show_samples(out, outputs)
return out
def run_img2img(
state, version_dict, is_legacy=False, return_latents=False, filter=None
state,
prompt: str,
negative_prompt: str,
return_latents=False,
stage2strength=None,
):
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
img = load_img()
if img is None:
return None
H, W = img.shape[2], img.shape[3]
params.height, params.width = img.shape[2], img.shape[3]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
)
num_rows, num_cols, sampler = init_sampling(
img2img_strength=strength,
use_identity_guider=not version_dict["is_guided"],
params.img2img_strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
params, num_rows, num_cols = init_sampling(params=params)
num_samples = num_rows * num_cols
if st.button("Sample"):
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
outputs = st.empty()
st.text("Sampling")
out = model.image_to_image(
image=repeat(img, "1 ... -> n ...", n=num_samples),
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=int(num_samples),
return_latents=return_latents,
filter=filter,
noise_strength=stage2strength,
filter=state["filter"],
)
show_samples(out, outputs)
return out
def apply_refiner(
input,
state,
sampler,
num_samples,
prompt,
negative_prompt,
filter=None,
num_samples: int,
prompt: str,
negative_prompt: str,
finish_denoising=False,
):
init_dict = {
"orig_width": input.shape[3] * 8,
"orig_height": input.shape[2] * 8,
"target_width": input.shape[3] * 8,
"target_height": input.shape[2] * 8,
}
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
value_dict = init_dict
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["crop_coords_top"] = 0
value_dict["crop_coords_left"] = 0
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
params.orig_width = input.shape[3] * 8
params.orig_height = input.shape[2] * 8
params.width = input.shape[3] * 8
params.height = input.shape[2] * 8
st.warning(f"refiner input shape: {input.shape}")
samples = do_img2img(
input,
state["model"],
sampler,
value_dict,
num_samples,
skip_encode=True,
filter=filter,
samples = model.refiner(
image=input,
params=params,
prompt=prompt,
negative_prompt=negative_prompt,
samples=num_samples,
return_latents=False,
filter=state["filter"],
add_noise=not finish_denoising,
)
return samples
sdxl_base_model_list = [
ModelArchitecture.SDXL_V1_0_BASE,
ModelArchitecture.SDXL_V0_9_BASE,
]
sdxl_refiner_model_list = [
ModelArchitecture.SDXL_V1_0_REFINER,
ModelArchitecture.SDXL_V0_9_REFINER,
]
if __name__ == "__main__":
st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
version = st.selectbox(
"Model Version",
[member.value for member in ModelArchitecture],
0,
)
version_enum = ModelArchitecture(version)
specs = model_specs[version_enum]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
if version == "SD-XL base":
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
st.write("**Performance Options:**")
use_fp16 = st.checkbox("Use fp16 (Saves VRAM)", True)
enable_swap = st.checkbox("Swap models to CPU (Saves VRAM, uses RAM)", True)
st.write("__________________________")
if version_enum in sdxl_base_model_list:
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
st.write("__________________________")
else:
add_pipeline = False
filter = DeepFloydDataFiltering(verbose=False)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed = int(
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
)
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict)
if state["msg"]:
st.info(state["msg"])
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
state = init_st(
model_specs[version_enum],
load_filter=True,
use_fp16=use_fp16,
enable_swap=enable_swap,
)
model = state["model"]
is_legacy = version_dict["is_legacy"]
is_legacy = specs.is_legacy
prompt = st.text_input(
"prompt",
@@ -263,47 +263,64 @@ if __name__ == "__main__":
else:
negative_prompt = "" # which is unused
stage2strength = None
finish_denoising = False
if add_pipeline:
st.write("__________________________")
version2 = "SDXL-Refiner"
version2 = ModelArchitecture(
st.selectbox(
"Refiner:",
[member.value for member in sdxl_refiner_model_list],
)
)
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2)
st.info(state2["msg"])
specs2 = model_specs[version2]
state2 = init_st(
specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
)
params2 = state2["params"]
stage2strength = st.number_input(
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
params2.img2img_strength = st.number_input(
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
)
sampler2 = init_sampling(
params2, *_ = init_sampling(
params=state2["params"],
key=2,
img2img_strength=stage2strength,
use_identity_guider=not version_dict["is_guided"],
get_num_samples=False,
specify_num_samples=False,
)
st.write("__________________________")
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
if finish_denoising:
stage2strength = params2.img2img_strength
else:
stage2strength = None
else:
state2 = None
params2 = None
stage2strength = None
if mode == "txt2img":
out = run_txt2img(
state,
version,
version_dict,
is_legacy=is_legacy,
state=state,
model_id=version_enum,
prompt=prompt,
negative_prompt=negative_prompt,
return_latents=add_pipeline,
filter=filter,
stage2strength=stage2strength,
)
elif mode == "img2img":
out = run_img2img(
state,
version_dict,
is_legacy=is_legacy,
state=state,
prompt=prompt,
negative_prompt=negative_prompt,
return_latents=add_pipeline,
filter=filter,
stage2strength=stage2strength,
)
else:
raise ValueError(f"unknown mode {mode}")
@@ -311,18 +328,20 @@ if __name__ == "__main__":
samples, samples_z = out
else:
samples = out
samples_z = None
if add_pipeline:
if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
state2,
sampler2,
samples_z.shape[0],
input=samples_z,
state=state2,
num_samples=samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=filter,
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)


@@ -1,136 +1,68 @@
import os
from typing import Union, List
import math
import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from torch import autocast
from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors
from typing import Optional, Tuple, Dict, Any
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
Discretization,
Guider,
Sampler,
SamplingParams,
SamplingSpec,
SamplingPipeline,
Thresholder,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking library expects input as cv2 format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
from sgm.inference.helpers import embed_watermark, CudaModelManager
@st.cache_resource()
def init_st(version_dict, load_ckpt=True):
state = dict()
if not "model" in state:
config = version_dict["config"]
ckpt = version_dict["ckpt"]
def init_st(
spec: SamplingSpec,
load_ckpt=True,
load_filter=True,
use_fp16=True,
enable_swap=True,
) -> Dict[str, Any]:
state: Dict[str, Any] = dict()
config = spec.config
ckpt = spec.ckpt
config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
state["msg"] = msg
state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
return state
def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt is not None:
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
global_step = pl_sd["global_step"]
st.info(f"loaded ckpt from global step {global_step}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
msg = None
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if enable_swap:
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=use_fp16,
device=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
msg = None
pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)
model.cuda()
model.eval()
return model, msg
state["spec"] = spec
state["model"] = pipeline
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
state["params"] = spec.default_params
if load_filter:
state["filter"] = DeepFloydDataFiltering(verbose=False)
else:
state["filter"] = None
return state
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
value_dict = {}
def init_embedder_options(
keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
for key in keys:
if key == "txt":
if prompt is None:
@@ -140,57 +72,38 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
if negative_prompt is None:
negative_prompt = st.text_input("Negative prompt", "")
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
if key == "original_size_as_tuple":
orig_width = st.number_input(
"orig_width",
value=init_dict["orig_width"],
value=params.orig_width,
min_value=16,
)
orig_height = st.number_input(
"orig_height",
value=init_dict["orig_height"],
value=params.orig_height,
min_value=16,
)
value_dict["orig_width"] = orig_width
value_dict["orig_height"] = orig_height
params.orig_width = int(orig_width)
params.orig_height = int(orig_height)
if key == "crop_coords_top_left":
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
value_dict["crop_coords_top"] = crop_coord_top
value_dict["crop_coords_left"] = crop_coord_left
if key == "aesthetic_score":
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
if key == "target_size_as_tuple":
target_width = st.number_input(
"target_width",
value=init_dict["target_width"],
min_value=16,
crop_coord_top = st.number_input(
"crop_coords_top", value=params.crop_coords_top, min_value=0
)
target_height = st.number_input(
"target_height",
value=init_dict["target_height"],
min_value=16,
crop_coord_left = st.number_input(
"crop_coords_left", value=params.crop_coords_left, min_value=0
)
value_dict["target_width"] = target_width
value_dict["target_height"] = target_height
return value_dict
params.crop_coords_top = int(crop_coord_top)
params.crop_coords_left = int(crop_coord_left)
return params
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watemark(samples)
samples = embed_watermark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
@@ -209,48 +122,26 @@ def init_save_locally(_dir, init_value: bool = False):
return save_locally, save_path
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large and then decrease
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def show_samples(samples, outputs):
if isinstance(samples, tuple):
samples, _ = samples
grid = embed_watermark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
def get_guider(key):
guider = st.sidebar.selectbox(
f"Discretization #{key}",
[
"VanillaCFG",
"IdentityGuider",
],
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
params.guider = Guider(
st.sidebar.selectbox(
f"Discretization #{key}", [member.value for member in Guider]
)
)
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
if params.guider == Guider.VANILLA:
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
)
params.scale = scale
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
[
@@ -259,179 +150,97 @@ def get_guider(key):
)
if thresholder == "None":
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
params.thresholder = Thresholder.NONE
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
return params
def init_sampling(
key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
):
if get_num_samples:
num_rows = 1
params: SamplingParams,
key=1,
specify_num_samples=True,
) -> Tuple[SamplingParams, int, int]:
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
steps = st.sidebar.number_input(
f"steps #{key}", value=50, min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
],
0,
)
discretization = st.sidebar.selectbox(
f"Discretization #{key}",
[
"LegacyDDPMDiscretization",
"EDMDiscretization",
],
params.steps = int(
st.sidebar.number_input(
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
)
)
discretization_config = get_discretization(discretization, key=key)
guider_config = get_guider(key=key)
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
params.sampler = Sampler(
st.sidebar.selectbox(
f"Sampler #{key}",
[member.value for member in Sampler],
0,
)
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
params.discretization = Discretization(
st.sidebar.selectbox(
f"Discretization #{key}",
[member.value for member in Discretization],
)
if get_num_samples:
return num_rows, num_cols, sampler
return sampler
)
params = get_discretization(params=params, key=key)
params = get_guider(params=params, key=key)
params = get_sampler(params=params, key=key)
return params, num_rows, num_cols
def get_discretization(discretization, key=1):
if discretization == "LegacyDDPMDiscretization":
use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
"params": {"legacy_range": not use_new_range},
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
rho = st.number_input(f"rho #{key}", value=3.0)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": sigma_min,
"sigma_max": sigma_max,
"rho": rho,
},
}
return discretization_config
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
if params.discretization == Discretization.EDM:
params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
params.rho = st.number_input(f"rho #{key}", value=params.rho)
return params
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
if sampler_name == "EulerEDMSampler":
sampler = EulerEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "HeunEDMSampler":
sampler = HeunEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=True,
)
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
):
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
if sampler_name == "EulerAncestralSampler":
sampler = EulerAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2SAncestralSampler":
sampler = DPMPP2SAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "DPMPP2MSampler":
sampler = DPMPP2MSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
params.s_churn = st.sidebar.number_input(
f"s_churn #{key}", value=params.s_churn, min_value=0.0
)
elif sampler_name == "LinearMultistepSampler":
order = st.sidebar.number_input("order", value=4, min_value=1)
sampler = LinearMultistepSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=order,
verbose=True,
params.s_tmin = st.sidebar.number_input(
f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
)
params.s_tmax = st.sidebar.number_input(
f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
)
params.s_noise = st.sidebar.number_input(
f"s_noise #{key}", value=params.s_noise, min_value=0.0
)
else:
raise ValueError(f"unknown sampler {sampler_name}!")
return sampler
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
elif params.sampler == Sampler.LINEAR_MULTISTEP:
params.order = int(
st.sidebar.number_input("order", value=params.order, min_value=1)
)
return params
def get_interactive_image(key=None) -> Image.Image:
def get_interactive_image(key=None) -> Optional[Image.Image]:
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
if image is not None:
image = Image.open(image)
if not image.mode == "RGB":
image = image.convert("RGB")
return image
return None
def load_img(display=True, key=None):
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
image = get_interactive_image(key=key)
if image is None:
return None
@@ -455,214 +264,3 @@ def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
return init_image
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
@torch.no_grad()
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: int = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
):
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0]
st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples

scripts/tests/attention.py Normal file

@@ -0,0 +1,319 @@
import torch
import einops
from torch.backends.cuda import SDPBackend
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
def benchmark_attn():
# Let's define a helpful benchmarking function:
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
device = "cuda" if torch.cuda.is_available() else "cpu"
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
key = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
value = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel
# Helpful arguments mapper
backend_map = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
}
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
print(
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with sdp_kernel(**backend_map[SDPBackend.MATH]):
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Math implmentation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
def run_model(model, x, context):
return model(x, context)
def benchmark_transformer_blocks():
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
checkpoint = True
compile = False
batch_size = 32
h, w = 64, 64
context_len = 77
embed_dimension = 1024
context_dim = 1024
d_head = 64
transformer_depth = 4
n_heads = embed_dimension // d_head
dtype = torch.float16
model_native = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
use_checkpoint=checkpoint,
attn_type="softmax",
depth=transformer_depth,
sdp_backend=SDPBackend.FLASH_ATTENTION,
).to(device)
model_efficient_attn = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
depth=transformer_depth,
use_checkpoint=checkpoint,
attn_type="softmax-xformers",
).to(device)
if not checkpoint and compile:
print("compiling models")
model_native = torch.compile(model_native)
model_efficient_attn = torch.compile(model_efficient_attn)
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with torch.autocast("cuda"):
print(
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
)
print(
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
)
print(75 * "+")
print("NATIVE")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("NativeAttention stats"):
for _ in range(25):
model_native(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
print(75 * "+")
print("Xformers")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("xformers stats"):
for _ in range(25):
model_efficient_attn(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
def test01():
# conv1x1 vs linear
from sgm.util import count_params
conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
print(count_params(conv))
linear = torch.nn.Linear(3, 32).cuda()
print(count_params(linear))
print(conv.weight.shape)
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
linear.bias = torch.nn.Parameter(conv.bias)
print(linear.weight.shape)
x = torch.randn(11, 3, 64, 64).cuda()
xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
print(xr.shape)
out_linear = linear(xr)
print(out_linear.mean(), out_linear.shape)
out_conv = conv(x)
print(out_conv.mean(), out_conv.shape)
print("done with test01.\n")
def test02():
# try cosine flash attention
import time
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print("testing cosine flash attention...")
DIM = 1024
SEQLEN = 4096
BS = 16
print(" softmax (vanilla) first...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="softmax",
).cuda()
try:
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
except RuntimeError as e:
# likely oom
print(str(e))
print("\n now flash-cosine...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="flash-cosine",
).cuda()
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
print("done with test02.\n")
if __name__ == "__main__":
# test01()
# test02()
# test03()
# benchmark_attn()
benchmark_transformer_blocks()
print("done.")

scripts/util/__init__.py Normal file


@@ -1,9 +1,10 @@
import os
import torch
import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import clip
RESOURCES_ROOT = "scripts/util/detection/"


@@ -1,13 +0,0 @@
from setuptools import find_packages, setup
setup(
name="sgm",
version="0.0.1",
packages=find_packages(),
python_requires=">=3.8",
py_modules=["sgm"],
description="Stability Generative Models",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
url="https://github.com/Stability-AI/generative-models",
)


@@ -1,3 +1,4 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config
from .util import get_configs_path, instantiate_from_config
__version__ = "0.1.1"


@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class CIFAR10DataDictWrapper(Dataset):


@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class MNISTDataDictWrapper(Dataset):

sgm/inference/api.py Normal file

@@ -0,0 +1,522 @@
from dataclasses import dataclass, asdict
from enum import Enum
from omegaconf import OmegaConf
import os
from sgm.inference.helpers import (
do_sample,
do_img2img,
DeviceModelManager,
get_model_manager,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
import torch
from typing import Optional, Dict, Any, Union
class ModelArchitecture(str, Enum):
SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
SD_2_1 = "stable-diffusion-v2-1"
SD_2_1_768 = "stable-diffusion-v2-1-768"
class Sampler(str, Enum):
EULER_EDM = "EulerEDMSampler"
HEUN_EDM = "HeunEDMSampler"
EULER_ANCESTRAL = "EulerAncestralSampler"
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
DPMPP2M = "DPMPP2MSampler"
LINEAR_MULTISTEP = "LinearMultistepSampler"
class Discretization(str, Enum):
LEGACY_DDPM = "LegacyDDPMDiscretization"
EDM = "EDMDiscretization"
class Guider(str, Enum):
VANILLA = "VanillaCFG"
IDENTITY = "IdentityGuider"
class Thresholder(str, Enum):
NONE = "None"
@dataclass
class SamplingParams:
"""
Parameters for sampling.
"""
width: Optional[int] = None
height: Optional[int] = None
steps: Optional[int] = None
sampler: Sampler = Sampler.EULER_EDM
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
thresholder: Thresholder = Thresholder.NONE
scale: float = 5.0
aesthetic_score: float = 6.0
negative_aesthetic_score: float = 2.5
img2img_strength: float = 1.0
orig_width: int = 1024
orig_height: int = 1024
crop_coords_top: int = 0
crop_coords_left: int = 0
sigma_min: float = 0.0292
sigma_max: float = 14.6146
rho: float = 3.0
s_churn: float = 0.0
s_tmin: float = 0.0
s_tmax: float = 999.0
s_noise: float = 1.0
eta: float = 1.0
order: int = 4
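# Note (sketch): width, height and steps default to None and are filled in from
# the chosen spec's default_params by SamplingPipeline.text_to_image below;
# everything else can be overridden directly when constructing the dataclass,
# e.g. SamplingParams(steps=30, scale=7.5, sampler=Sampler.DPMPP2M).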
@dataclass
class SamplingSpec:
width: int
height: int
channels: int
factor: int
is_legacy: bool
config: str
ckpt: str
is_guided: bool
default_params: SamplingParams
# The defaults here are derived from user preference testing.
model_specs = {
ModelArchitecture.SD_2_1: SamplingSpec(
height=512,
width=512,
channels=4,
factor=8,
is_legacy=True,
config="sd_2_1.yaml",
ckpt="v2-1_512-ema-pruned.safetensors",
is_guided=True,
default_params=SamplingParams(
width=512,
height=512,
steps=40,
scale=7.0,
),
),
ModelArchitecture.SD_2_1_768: SamplingSpec(
height=768,
width=768,
channels=4,
factor=8,
is_legacy=True,
config="sd_2_1_768.yaml",
ckpt="v2-1_768-ema-pruned.safetensors",
is_guided=True,
default_params=SamplingParams(
width=768,
height=768,
steps=40,
scale=7.0,
),
),
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=False,
config="sd_xl_base.yaml",
ckpt="sd_xl_base_0.9.safetensors",
is_guided=True,
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
),
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=True,
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_0.9.safetensors",
is_guided=True,
default_params=SamplingParams(
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
),
),
ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=False,
config="sd_xl_base.yaml",
ckpt="sd_xl_base_1.0.safetensors",
is_guided=True,
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
),
ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
height=1024,
width=1024,
channels=4,
factor=8,
is_legacy=True,
config="sd_xl_refiner.yaml",
ckpt="sd_xl_refiner_1.0.safetensors",
is_guided=True,
default_params=SamplingParams(
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
),
),
}
def wrap_discretization(
discretization, image_strength=None, noise_strength=None, steps=None
):
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
discretization, Txt2NoisyDiscretizationWrapper
):
return discretization # Already wrapped
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
discretization = Img2ImgDiscretizationWrapper(
discretization, strength=image_strength
)
if (
noise_strength is not None
and noise_strength < 1.0
and noise_strength > 0.0
and steps is not None
):
discretization = Txt2NoisyDiscretizationWrapper(
discretization,
strength=noise_strength,
original_steps=steps,
)
return discretization
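# Illustrative sketch (the helper name is hypothetical): how wrap_discretization
# combines with get_sampler_config below for an img2img-style run. A strength of
# 0.3 keeps roughly the lowest-noise 30% of the sigma schedule, so sampling
# starts from a partially noised latent rather than pure noise.
def _example_wrap_for_img2img():
    params = model_specs[ModelArchitecture.SDXL_V1_0_BASE].default_params
    sampler = get_sampler_config(params)
    sampler.discretization = wrap_discretization(
        sampler.discretization, image_strength=0.3
    )
    return sampler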
class SamplingPipeline:
def __init__(
self,
model_id: Optional[ModelArchitecture] = None,
model_spec: Optional[SamplingSpec] = None,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
) -> None:
"""
Sampling pipeline for generating images from a model.
@param model_id: Model architecture to use. If not specified, model_spec must be specified.
@param model_spec: Model specification to use. If not specified, model_id must be specified.
@param model_path: Path to model checkpoints folder.
@param config_path: Path to model config folder.
@param use_fp16: Whether to use fp16 for sampling.
@param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
"""
self.model_id = model_id
if model_spec is not None:
self.specs = model_spec
elif model_id is not None:
if model_id not in model_specs:
raise ValueError(f"Model {model_id} not supported")
self.specs = model_specs[model_id]
else:
raise ValueError("Either model_id or model_spec should be provided")
if model_path is None:
model_path = get_checkpoints_path()
if config_path is None:
config_path = get_configs_path()
self.config = os.path.join(config_path, "inference", self.specs.config)
self.ckpt = os.path.join(model_path, self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device_manager = get_model_manager(device)
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)
if model is None:
raise ValueError(f"Model {self.model_id} could not be loaded")
device_manager.load(model)
if use_fp16:
model.conditioner.half()
model.model.half()
return model
def text_to_image(
self,
prompt: str,
params: Optional[SamplingParams] = None,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter=None,
):
if params is None:
params = self.specs.default_params
else:
# Set defaults if optional params are not specified
if params.width is None:
params.width = self.specs.default_params.width
if params.height is None:
params.height = self.specs.default_params.height
if params.steps is None:
params.steps = self.specs.default_params.steps
sampler = get_sampler_config(params)
sampler.discretization = wrap_discretization(
sampler.discretization,
image_strength=None,
noise_strength=noise_strength,
steps=params.steps,
)
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["target_width"] = params.width
value_dict["target_height"] = params.height
return do_sample(
self.model,
sampler,
value_dict,
samples,
params.height,
params.width,
self.specs.channels,
self.specs.factor,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device_manager,
)
def image_to_image(
self,
image: torch.Tensor,
prompt: str,
params: Optional[SamplingParams] = None,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength: Optional[float] = None,
filter=None,
):
if params is None:
params = self.specs.default_params
sampler = get_sampler_config(params)
sampler.discretization = wrap_discretization(
sampler.discretization,
image_strength=params.img2img_strength,
noise_strength=noise_strength,
steps=params.steps,
)
height, width = image.shape[2], image.shape[3]
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["target_width"] = width
value_dict["target_height"] = height
value_dict["orig_width"] = width
value_dict["orig_height"] = height
return do_img2img(
image,
self.model,
sampler,
value_dict,
samples,
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
return_latents=return_latents,
filter=filter,
device=self.device_manager,
)
def refiner(
self,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
params: Optional[SamplingParams] = None,
samples: int = 1,
return_latents: bool = False,
filter: Any = None,
add_noise: bool = False,
):
if params is None:
params = self.specs.default_params
sampler = get_sampler_config(params)
value_dict = {
"orig_width": image.shape[3] * 8,
"orig_height": image.shape[2] * 8,
"target_width": image.shape[3] * 8,
"target_height": image.shape[2] * 8,
"prompt": prompt,
"negative_prompt": negative_prompt,
"crop_coords_top": 0,
"crop_coords_left": 0,
"aesthetic_score": 6.0,
"negative_aesthetic_score": 2.5,
}
sampler.discretization = wrap_discretization(
sampler.discretization, image_strength=params.img2img_strength
)
return do_img2img(
image,
self.model,
sampler,
value_dict,
samples,
skip_encode=True,
return_latents=return_latents,
filter=filter,
add_noise=add_noise,
device=self.device_manager,
)
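# Usage sketch (the helper name is hypothetical): the minimal way to drive the
# pipeline above. Assumes the SDXL 1.0 base checkpoint and config resolve via
# get_checkpoints_path() / get_configs_path(); the prompt text is arbitrary.
def _example_text_to_image():
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SDXL_V1_0_BASE)
    # params=None picks up the spec's default_params (1024x1024, 40 steps)
    return pipeline.text_to_image(
        prompt="a photograph of an astronaut riding a horse",
        samples=1,
    )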
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
guider_config: Dict[str, Any]
if params.guider == Guider.IDENTITY:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif params.guider == Guider.VANILLA:
scale = params.scale
thresholder = params.thresholder
if thresholder == Thresholder.NONE:
dyn_thresh_config = {
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
}
else:
raise NotImplementedError
return guider_config
def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
discretization_config: Dict[str, Any]
if params.discretization == Discretization.LEGACY_DDPM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif params.discretization == Discretization.EDM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": params.sigma_min,
"sigma_max": params.sigma_max,
"rho": params.rho,
},
}
else:
raise ValueError(f"unknown discretization {params.discretization}")
return discretization_config
def get_sampler_config(params: SamplingParams):
discretization_config = get_discretization_config(params)
guider_config = get_guider_config(params)
sampler = None
if params.sampler == Sampler.EULER_EDM:
return EulerEDMSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=params.s_churn,
s_tmin=params.s_tmin,
s_tmax=params.s_tmax,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.HEUN_EDM:
return HeunEDMSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=params.s_churn,
s_tmin=params.s_tmin,
s_tmax=params.s_tmax,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.EULER_ANCESTRAL:
return EulerAncestralSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=params.eta,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
return DPMPP2SAncestralSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=params.eta,
s_noise=params.s_noise,
verbose=True,
)
if params.sampler == Sampler.DPMPP2M:
return DPMPP2MSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=True,
)
if params.sampler == Sampler.LINEAR_MULTISTEP:
return LinearMultistepSampler(
num_steps=params.steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=params.order,
verbose=True,
)
raise ValueError(f"unknown sampler {params.sampler}!")

sgm/inference/helpers.py Normal file

@@ -0,0 +1,441 @@
import contextlib
import os
from typing import Union, List, Optional
import math
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from sgm.util import append_dims
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking library expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
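# Sketch: embed_watermark is applied directly to decoded samples in [0, 1],
# either (B, C, H, W) or (N, B, C, H, W); shape and value range are preserved.
# The random tensor below is only a stand-in for real decoded samples.
#
#     samples = torch.rand(2, 3, 256, 256)
#     watermarked = embed_watermark(samples)  # same shape, still in [0, 1]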
class DeviceModelManager(object):
"""
Default model loading class, should work for all device classes.
"""
def __init__(
self,
device: Union[torch.device, str],
swap_device: Optional[Union[torch.device, str]] = None,
) -> None:
"""
Args:
device (Union[torch.device, str]): The device to use for the model.
swap_device (Optional[Union[torch.device, str]]): Device to park the model on when not in use; defaults to `device`.
"""
self.device = torch.device(device)
self.swap_device = (
torch.device(swap_device) if swap_device is not None else self.device
)
def load(self, model: torch.nn.Module) -> None:
"""
Loads a model to the (swap) device.
"""
model.to(self.swap_device)
def autocast(self):
"""
Context manager that enables autocast for the device if supported.
"""
if self.device.type not in ("cuda", "cpu"):
return contextlib.nullcontext()
return torch.autocast(self.device.type)
@contextlib.contextmanager
def use(self, model: torch.nn.Module):
"""
Context manager that ensures a model is on the correct device during use.
The default model loader does not perform any swapping, so the model will
stay on device.
"""
try:
model.to(self.device)
yield
finally:
if self.device != self.swap_device:
model.to(self.swap_device)
class CudaModelManager(DeviceModelManager):
"""
Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
"""
@contextlib.contextmanager
def use(self, model: Union[torch.nn.Module, torch.Tensor]):
"""
Context manager that ensures a model is on the correct device during use.
If a swap device was provided, this will move the model to it after use and clear cache.
"""
model.to(self.device)
yield
if self.device != self.swap_device:
model.to(self.swap_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
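# Sketch (the helper name is hypothetical): CudaModelManager parks a module on
# the swap device between uses and only moves it to the GPU inside `use`, which
# also clears the CUDA cache afterwards. Assumes a CUDA device is available.
def _example_swap_usage():
    manager = CudaModelManager(device="cuda", swap_device="cpu")
    net = torch.nn.Linear(8, 8)
    manager.load(net)  # parked on the CPU swap device
    with manager.use(net):
        out = net(torch.zeros(1, 8, device=manager.device))
    # net is back on CPU here and torch.cuda.empty_cache() has been called
    return out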
def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watermark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
os.path.join(save_path, f"{base_count:09}.png")
)
base_count += 1
def get_model_manager(
device: Optional[Union[DeviceModelManager, str, torch.device]]
) -> DeviceModelManager:
if isinstance(device, DeviceModelManager):
return device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
if device.type == "cuda":
return CudaModelManager(device=device)
else:
return DeviceModelManager(device=device)
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large and then decrease
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
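# Sketch: with strength=0.5 only the lowest-noise half of the schedule survives,
# so img2img denoises from a partially noised latent instead of pure noise.
# `disc` stands for any existing discretization callable and `num_steps` for its
# step count; both names are placeholders.
#
#     wrapped = Img2ImgDiscretizationWrapper(disc, strength=0.5)
#     sigmas = wrapped(num_steps)  # about half as many sigmas as disc(num_steps)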
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large and then decrease
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
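# Sketch: Txt2NoisyDiscretizationWrapper is the txt2img counterpart used for the
# "Finish denoising with refiner" path; here strength works the other way round
# (0.0 keeps the full schedule). The base model stops early and the remaining
# low-noise sigmas are left for the refiner. `disc` and `num_steps` are
# placeholders; 0.15 and 40 match the refiner/demo defaults above.
#
#     wrapped = Txt2NoisyDiscretizationWrapper(disc, strength=0.15, original_steps=40)
#     sigmas = wrapped(num_steps)  # drops roughly the last 15% of the schedule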
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: Optional[List] = None,
batch2model_input: Optional[List] = None,
return_latents=False,
filter=None,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
device_manager = get_model_manager(device=device)
with torch.no_grad():
with device_manager.autocast():
with model.ema_scope():
num_samples = [num_samples]
with device_manager.use(model.conditioner):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to(
device_manager.device
),
(c, uc),
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to(device_manager.device)
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
with device_manager.use(model.denoiser):
with device_manager.use(model.model):
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
with device_manager.use(model.first_stage_model):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def get_input_image_tensor(image: Image.Image, device="cuda"):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image_array = np.array(image.convert("RGB"))
image_array = image_array[None].transpose(0, 3, 1, 2)
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
return image_tensor.to(device)
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: float = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
add_noise=True,
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
):
device_manager = get_model_manager(device)
with torch.no_grad():
with device_manager.autocast():
with model.ema_scope():
with device_manager.use(model.conditioner):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
c[k], uc[k] = map(
lambda y: y[k][:num_samples].to(device_manager.device), (c, uc)
)
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
with device_manager.use(model.first_stage_model):
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps)
sigma = sigmas[0].to(z.device)
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
with device_manager.use(model.denoiser):
with device_manager.use(model.model):
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
with device_manager.use(model.first_stage_model):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples


@@ -258,14 +258,10 @@ class DiffusionEngine(pl.LightningModule):
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
elif isinstance(x, Union[List, ListConfig]):
elif isinstance(x, (List, ListConfig)):
if isinstance(x[0], str):
# strings
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
elif isinstance(x[0], Union[ListConfig, List]):
# # case: videos processed
x = [xx[0] for xx in x]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
else:


@@ -631,317 +631,3 @@ class SpatialTransformer(nn.Module):
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
def benchmark_attn():
# Let's define a helpful benchmarking function:
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
key = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
value = torch.rand(
batch_size,
num_heads,
max_sequence_len,
embed_dimension,
device=device,
dtype=dtype,
)
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import SDPBackend, sdp_kernel
# Helpful arguments mapper
backend_map = {
SDPBackend.MATH: {
"enable_math": True,
"enable_flash": False,
"enable_mem_efficient": False,
},
SDPBackend.FLASH_ATTENTION: {
"enable_math": False,
"enable_flash": True,
"enable_mem_efficient": False,
},
SDPBackend.EFFICIENT_ATTENTION: {
"enable_math": False,
"enable_flash": False,
"enable_mem_efficient": True,
},
}
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
print(
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Default detailed stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
with sdp_kernel(**backend_map[SDPBackend.MATH]):
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("Math implmentation stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
try:
print(
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("FlashAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
try:
print(
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
)
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("EfficientAttention stats"):
for _ in range(25):
o = F.scaled_dot_product_attention(query, key, value)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
def run_model(model, x, context):
return model(x, context)
def benchmark_transformer_blocks():
device = "cuda" if torch.cuda.is_available() else "cpu"
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
checkpoint = True
compile = False
batch_size = 32
h, w = 64, 64
context_len = 77
embed_dimension = 1024
context_dim = 1024
d_head = 64
transformer_depth = 4
n_heads = embed_dimension // d_head
dtype = torch.float16
model_native = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
use_checkpoint=checkpoint,
attn_type="softmax",
depth=transformer_depth,
sdp_backend=SDPBackend.FLASH_ATTENTION,
).to(device)
model_efficient_attn = SpatialTransformer(
embed_dimension,
n_heads,
d_head,
context_dim=context_dim,
use_linear=True,
depth=transformer_depth,
use_checkpoint=checkpoint,
attn_type="softmax-xformers",
).to(device)
if not checkpoint and compile:
print("compiling models")
model_native = torch.compile(model_native)
model_efficient_attn = torch.compile(model_efficient_attn)
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
from torch.profiler import ProfilerActivity, profile, record_function
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with torch.autocast("cuda"):
print(
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
)
print(
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
)
print(75 * "+")
print("NATIVE")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("NativeAttention stats"):
for _ in range(25):
model_native(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
print(75 * "+")
print("Xformers")
print(75 * "+")
torch.cuda.reset_peak_memory_stats()
with profile(
activities=activities, record_shapes=False, profile_memory=True
) as prof:
with record_function("xformers stats"):
for _ in range(25):
model_efficient_attn(x, c)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
def test01():
# conv1x1 vs linear
from ..util import count_params
conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
print(count_params(conv))
linear = torch.nn.Linear(3, 32).cuda()
print(count_params(linear))
print(conv.weight.shape)
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
linear.bias = torch.nn.Parameter(conv.bias)
print(linear.weight.shape)
x = torch.randn(11, 3, 64, 64).cuda()
xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
print(xr.shape)
out_linear = linear(xr)
print(out_linear.mean(), out_linear.shape)
out_conv = conv(x)
print(out_conv.mean(), out_conv.shape)
print("done with test01.\n")
def test02():
# try cosine flash attention
import time
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print("testing cosine flash attention...")
DIM = 1024
SEQLEN = 4096
BS = 16
print(" softmax (vanilla) first...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="softmax",
).cuda()
try:
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
except RuntimeError as e:
# likely oom
print(str(e))
print("\n now flash-cosine...")
model = BasicTransformerBlock(
dim=DIM,
n_heads=16,
d_head=64,
dropout=0.0,
context_dim=None,
attn_mode="flash-cosine",
).cuda()
x = torch.randn(BS, SEQLEN, DIM).cuda()
tic = time.time()
y = model(x)
toc = time.time()
print(y.shape, toc - tic)
print("done with test02.\n")
if __name__ == "__main__":
# test01()
# test02()
# test03()
# benchmark_attn()
benchmark_transformer_blocks()
print("done.")


@@ -3,11 +3,11 @@ from typing import Any, Union
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):


@@ -0,0 +1 @@
vgg.pth


@@ -0,0 +1,23 @@
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -0,0 +1,147 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
from ..util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512]  # vgg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
self.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
outs1[kk]
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer(
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
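A short usage sketch for the LPIPS metric above, assuming the vgg_lpips checkpoint can be fetched by get_ckpt_path; the random tensors are placeholders for real images in roughly [-1, 1]:

lpips = LPIPS().eval()
img_a = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
img_b = torch.rand(1, 3, 256, 256) * 2.0 - 1.0
with torch.no_grad():
    distance = lpips(img_a, img_b)  # shape (1, 1, 1, 1); lower means more similar
print(distance.item())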


@@ -0,0 +1,58 @@
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -0,0 +1,88 @@
import functools
import torch.nn as nn
from ..util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if (
type(norm_layer) == functools.partial
): # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)


@@ -0,0 +1,128 @@
import hashlib
import os
import requests
import torch
import torch.nn as nn
from tqdm import tqdm
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class ActNorm(nn.Module):
def __init__(
self, num_features, logdet=False, affine=True, allow_reverse_init=False
):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:, :, None, None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height * width * torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:, :, None, None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
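A minimal sketch of ActNorm's data-dependent initialization: the first training-mode forward pass sets loc/scale from batch statistics, and reverse() inverts the transform (values here are illustrative):

actnorm = ActNorm(num_features=64, logdet=True)
actnorm.train()
x = torch.randn(8, 64, 32, 32)
h, logdet = actnorm(x)                # first call initializes loc/scale from x
assert actnorm.initialized.item() == 1
x_rec = actnorm(h, reverse=True)      # recovers x up to floating-point error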


@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real))
+ torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
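For orientation, a hedged sketch of how the PatchGAN discriminator and the hinge loss above combine in a discriminator update; the tensors are random stand-ins, and NLayerDiscriminator, weights_init, and hinge_d_loss are assumed to be importable from the modules shown above:

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
real = torch.randn(4, 3, 256, 256)
fake = torch.randn(4, 3, 256, 256)      # would normally be decoder reconstructions
logits_real = disc(real)                # (4, 1, h', w') patch-wise predictions
logits_fake = disc(fake.detach())
d_loss = hinge_d_loss(logits_real, logits_fake)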


@@ -1,7 +1,7 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Model, Encoder, Decoder
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper


@@ -1,25 +1,37 @@
import torch
import numpy as np
from abc import abstractmethod
from functools import partial
from ...util import append_zero
import numpy as np
import torch
from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero
def generate_roughly_equally_spaced_steps(
num_substeps: int, max_step: int
) -> np.ndarray:
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
class Discretization:
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
sigmas = self.get_sigmas(n, device)
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
sigmas = self.get_sigmas(n, device=device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))
@abstractmethod
def get_sigmas(self, n, device):
pass
class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
def get_sigmas(self, n, device):
def get_sigmas(self, n, device="cpu"):
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
@@ -33,8 +45,8 @@ class LegacyDDPMDiscretization(Discretization):
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
legacy_range=True,
):
super().__init__()
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
@@ -42,23 +54,15 @@ class LegacyDDPMDiscretization(Discretization):
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
self.legacy_range = legacy_range
def get_sigmas(self, n, device):
def get_sigmas(self, n, device="cpu"):
if n < self.num_timesteps:
c = self.num_timesteps // n
if self.legacy_range:
timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
timesteps += 1 # Legacy LDM Hack
else:
timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
timesteps -= 1
timesteps = timesteps[1:]
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
else:
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
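For reference, the EDM schedule above interpolates linearly in sigma**(1/rho) space (Karras et al., 2022) and maps back by raising to the rho power; since the hunk is truncated, the final line below is an assumption based on that formulation rather than a quote of the file:

import torch


def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0, device="cpu"):
    # Highest noise level first: sigma(0) == sigma_max, sigma(n-1) == sigma_min.
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho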


@@ -3,9 +3,9 @@ from typing import List, Optional, Union
import torch
import torch.nn as nn
from omegaconf import ListConfig
from taming.modules.losses.lpips import LPIPS
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
class StandardDiffusionLoss(nn.Module):


@@ -30,5 +30,5 @@ class OpenAIWrapper(IdentityWrapper):
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
**kwargs
**kwargs,
)


@@ -1,5 +1,5 @@
import torch
import numpy as np
import torch
class AbstractDistribution:


@@ -212,7 +212,6 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
raise NotImplementedError
model = instantiate_from_config(config.model)
sd = pl_sd["state_dict"]
m, u = model.load_state_dict(sd, strict=False)
@@ -229,3 +228,39 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
model.eval()
return model
def get_checkpoints_path() -> str:
"""
Get the `checkpoints` directory.
This could be in the root of the repository for a working copy,
or in the cwd for other use cases.
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, "checkpoints"),
os.path.join(os.getcwd(), "checkpoints"),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
def get_configs_path() -> str:
"""
Get the `configs` directory.
For a working copy, this is the one in the root of the repository,
but for an installed copy, it's in the `sgm` package (see pyproject.toml).
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, "configs"),
os.path.join(this_dir, "..", "configs"),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f"Could not find SGM configs in {candidates}")


@@ -0,0 +1,109 @@
import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple
from sgm.inference.api import (
model_specs,
SamplingParams,
SamplingPipeline,
Sampler,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
@pytest.mark.inference
class TestInference:
@fixture(scope="class", params=model_specs.keys())
def pipeline(self, request) -> SamplingPipeline:
pipeline = SamplingPipeline(request.param)
yield pipeline
del pipeline
torch.cuda.empty_cache()
@fixture(
scope="class",
params=[
[ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
[ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
],
ids=["SDXL_V1", "SDXL_V0_9"],
)
def sdxl_pipelines(self, request) -> Tuple[SamplingPipeline, SamplingPipeline]:
base_pipeline = SamplingPipeline(request.param[0])
refiner_pipeline = SamplingPipeline(request.param[1])
yield base_pipeline, refiner_pipeline
del base_pipeline
del refiner_pipeline
torch.cuda.empty_cache()
def create_init_image(self, h, w):
image_array = numpy.random.rand(h, w, 3) * 255
image = Image.fromarray(image_array.astype("uint8")).convert("RGB")
return helpers.get_input_image_tensor(image)
@pytest.mark.parametrize("sampler_enum", Sampler)
def test_txt2img(self, pipeline: SamplingPipeline, sampler_enum):
output = pipeline.text_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
)
assert output is not None
@pytest.mark.parametrize("sampler_enum", Sampler)
def test_img2img(self, pipeline: SamplingPipeline, sampler_enum):
output = pipeline.image_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=self.create_init_image(pipeline.specs.height, pipeline.specs.width),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
)
assert output is not None
@pytest.mark.parametrize("sampler_enum", Sampler)
@pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"])
def test_sdxl_with_refiner(
self,
sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline],
sampler_enum,
use_init_image,
):
base_pipeline, refiner_pipeline = sdxl_pipelines
if use_init_image:
output = base_pipeline.image_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
return_latents=True,
noise_strength=0.15,
)
else:
output = base_pipeline.text_to_image(
params=SamplingParams(sampler=sampler_enum.value, steps=10),
prompt="A professional photograph of an astronaut riding a pig",
negative_prompt="",
samples=1,
return_latents=True,
noise_strength=0.15,
)
assert isinstance(output, (tuple, list))
samples, samples_z = output
assert samples is not None
assert samples_z is not None
refiner_pipeline.refiner(
image=samples_z,
prompt="A professional photograph of an astronaut riding a pig",
params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15),
negative_prompt="",
samples=1,
)


@@ -0,0 +1,44 @@
import pytest
import torch
from sgm.inference.api import (
SamplingPipeline,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
def get_torch_device(model: torch.nn.Module) -> torch.device:
param = next(model.parameters(), None)
if param is not None:
return param.device
else:
buf = next(model.buffers(), None)
if buf is not None:
return buf.device
else:
raise TypeError("Could not determine device of input model")
@pytest.mark.inference
def test_default_loading():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cuda"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
@pytest.mark.inference
def test_model_swapping():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
assert get_torch_device(pipeline.model.model).type == "cpu"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cpu"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"