Compare commits
81 Commits
| Author | SHA1 | Date |
|---|---|---|
| | e8cd657656 | |
| | 0a4ea360db | |
| | 8f41cbc50b | |
| | f87e52e72c | |
| | 0ad7de9a5c | |
| | c3147b86db | |
| | 1659a1c09b | |
| | 37ab71e234 | |
| | e90e953330 | |
| | da40ebad4e | |
| | 50364a7d2f | |
| | 2cea114cc1 | |
| | 734195d1c9 | |
| | 854bd4f0df | |
| | e0596f1aca | |
| | ce1576bfca | |
| | 1cd0cbaff4 | |
| | 863665548f | |
| | e3e4b9d263 | |
| | 1aa06e5995 | |
| | 998cb122d3 | |
| | 31fe459a85 | |
| | abe9ed3d40 | |
| | fbdc58cab9 | |
| | bdbae9948f | |
| | 2a532db0e8 | |
| | fba930d400 | |
| | b4b7b644a1 | |
| | c51e4e30c2 | |
| | 1e30a2df80 | |
| | 9d759324e9 | |
| | a3803c007b | |
| | e6f0e36f5e | |
| | ed0997173f | |
| | f3458e2a9e | |
| | a8b4e89ca1 | |
| | 4757f16482 | |
| | 059d8e9cd9 | |
| | 477d8b9a77 | |
| | 45c443b316 | |
| | dea60596fc | |
| | 299abbcd90 | |
| | e5d714d304 | |
| | f2fa96b7e5 | |
| | c60c091f4d | |
| | 931d7a389a | |
| | e596332148 | |
| | 4a3f0f546e | |
| | 7934245835 | |
| | 1da250906d | |
| | a4ceca6d03 | |
| | 68f3f89bd3 | |
| | 57862fb4c7 | |
| | ef520df1db | |
| | 2897fdc99a | |
| | b5b5680150 | |
| | 6f6d3f8716 | |
| | 6ecd0a900a | |
| | e25e4c0df1 | |
| | e5dc9669ed | |
| | 48904a692d | |
| | 5c10deee76 | |
| | 89f5413e6d | |
| | ba3e7fed5a | |
| | ea89ce793d | |
| | 7b1978e055 | |
| | 9d5ace911e | |
| | 95b9acc5c6 | |
| | 5df4d9893c | |
| | ae18ba3e87 | |
| | 061d11d55d | |
| | 2796c81a5f | |
| | e9869d7822 | |
| | 613af104c6 | |
| | 376cec3b0f | |
| | 76e549dd94 | |
| | 5f0a2fcf48 | |
| | d8a6a97fb0 | |
| | a1af4ac4f1 | |
| | 58ddbee3ee | |
| | bec98beff8 | |
15  .github/workflows/black.yml  (vendored, new file)
@@ -0,0 +1,15 @@
+name: Run black
+on: [pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Install venv
+        run: |
+          sudo apt-get -y install python3.10-venv
+      - uses: psf/black@stable
+        with:
+          options: "--check --verbose -l88"
+          src: "./sgm ./scripts ./main.py"
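To reproduce this lint check locally before opening a pull request, a minimal sketch (assuming `black` is installed from PyPI; the options mirror the workflow above):

```bash
# mirror the CI lint job locally; -l88 matches the line length the workflow passes to psf/black@stable
pip install black
black --check --verbose -l88 ./sgm ./scripts ./main.py
```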
27  .github/workflows/test-build.yaml  (vendored, new file)
@@ -0,0 +1,27 @@
+name: Build package
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+
+jobs:
+  build:
+    name: Build
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.8", "3.10"]
+        requirements-file: ["pt2", "pt13"]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements/${{ matrix.requirements-file }}.txt
+          pip install .
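A rough local equivalent of one cell of the build matrix, as a sketch (Python 3.10 with the `pt2` requirements file; the venv name `.build-test` is an arbitrary choice, not part of the workflow):

```bash
# approximate the python-version=3.10 / requirements-file=pt2 matrix cell locally
python3.10 -m venv .build-test
source .build-test/bin/activate
python -m pip install --upgrade pip
pip install -r requirements/pt2.txt
pip install .
```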
34  .github/workflows/test-inference.yml  (vendored, new file)
@@ -0,0 +1,34 @@
+name: Test inference
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+
+jobs:
+  test:
+    name: "Test inference"
+    # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
+    if: github.repository == 'stability-ai/generative-models'
+    runs-on: [self-hosted, slurm, g40]
+    steps:
+      - uses: actions/checkout@v3
+      - name: "Symlink checkpoints"
+        run: ln -s ${{vars.SGM_CHECKPOINTS_PATH}} checkpoints
+      - name: "Setup python"
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: "Install Hatch"
+        run: pip install hatch
+      - name: "Run inference tests"
+        run: hatch run ci:test-inference --junit-xml test-results.xml
+      - name: Surface failing tests
+        if: always()
+        uses: pmeier/pytest-results-action@main
+        with:
+          path: test-results.xml
+          summary: true
+          display-options: fEX
+          fail-on-empty: true
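Outside the Stability cluster the job will not trigger (see the `if:` guard), but the test invocation itself can be reproduced locally, a sketch assuming checkpoints are already available at some local path (`SGM_CHECKPOINTS_PATH` is a runner variable, not something defined in this repo):

```bash
# substitute your own checkpoint directory for the runner's SGM_CHECKPOINTS_PATH
ln -s /path/to/your/checkpoints checkpoints
pip install hatch
hatch run ci:test-inference --junit-xml test-results.xml
```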
19  .gitignore  (vendored)
@@ -1,7 +1,16 @@
-.pt2
-.pt2_2
-.pt13
+# extensions
 *.egg-info
-build
-/outputs
+*.py[cod]
+
+# envs
+.pt13
+.pt2
+
+# directories
 /checkpoints
+/dist
+/outputs
+/build
+/src
+/.vscode
+**/__pycache__/
1  CODEOWNERS  (new file)
@@ -0,0 +1 @@
+.github @Stability-AI/infrastructure
21  LICENSE-CODE  (new file)
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Stability AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
286  README.md  (Normal file → Executable file)
@@ -4,40 +4,188 @@
 
 ## News
 
 
+**May 20, 2025**
+- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
+  - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
+  - Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-views of the first frame generated by SV3D, making it more robust to self-occlusions.
+  - To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames.
+  - Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.
+
+**QUICKSTART** :
+- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`)
+
+To run **SV4D 2.0** on a single input video of 21 frames:
+- Download SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`
+- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
+  - `input_path` : The input video `<path/to/video>` can be
+    - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or
+    - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
+    - a file name pattern matching images of video frames.
+  - `num_steps` : default is 50, can decrease it to shorten sampling time.
+  - `elevations_deg` : specified elevations (relative to input view), default is 0.0 (same as input view).
+  - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
+  - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution like `--img_size=512`.
+
+Notes:
+- We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D).
+  - Download the model from huggingface: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`
+  - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`
+- The 5x8 model takes 5 frames of input at a time, but the inference scripts for both models take a 21-frame video as input by default (same as SV3D and SV4D); we run the model autoregressively until we generate 21 frames.
+- Install dependencies before running:
+```
+python3.10 -m venv .generativemodels
+source .generativemodels/bin/activate
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version
+pip3 install -r requirements/pt2.txt
+pip3 install .
+pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+```
+
+![tile](assets/sv4d2.gif)
+
+
+**July 24, 2024**
+- We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
+  - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
+  - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
+  - To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`.
+  - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
+
+**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
+
+To run **SV4D** on a single input video of 21 frames:
+- Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
+- Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path/to/video>`
+  - `input_path` : The input video `<path/to/video>` can be
+    - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or
+    - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
+    - a file name pattern matching images of video frames.
+  - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
+  - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
+  - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
+  - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
+  - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution like `--img_size=512`.
+
+![tile](assets/sv4d.gif)
+
+
+**March 18, 2024**
+- We are releasing **[SV3D](https://huggingface.co/stabilityai/sv3d)**, an image-to-video model for novel multi-view synthesis, for research purposes:
+  - **SV3D** was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
+  - **SV3D_u**: This variant generates orbital videos based on single image inputs without camera conditioning.
+  - **SV3D_p**: Extending the capability of **SV3D_u**, this variant accommodates both single images and orbital views, allowing for the creation of 3D video along specified camera paths.
+  - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+  - Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details.
+
+To run **SV3D_u** on a single image:
+- Download `sv3d_u.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_u.safetensors`
+- Run `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_u`
+
+To run **SV3D_p** on a single image:
+- Download `sv3d_p.safetensors` from https://huggingface.co/stabilityai/sv3d to `checkpoints/sv3d_p.safetensors`
+1. Generate static orbit at a specified elevation, e.g. 10.0: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg 10.0`
+2. Generate dynamic orbit at specified elevations and azimuths: specify sequences of 21 elevations (in degrees) to `elevations_deg` ([-90, 90]), and 21 azimuths (in degrees) to `azimuths_deg` [0, 360] in sorted order from 0 to 360. For example: `python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p --elevations_deg [<list of 21 elevations in degrees>] --azimuths_deg [<list of 21 azimuths in degrees>]`
+
+To run SVD or SV3D on a streamlit server:
+`streamlit run scripts/demo/video_sampling.py`
+
+![tile](assets/sv3d.gif)
+
+
+**November 28, 2023**
+- We are releasing SDXL-Turbo, a lightning fast text-to-image model.
+  Alongside the model, we release a [technical report](https://stability.ai/research/adversarial-diffusion-distillation)
+- Usage:
+  - Follow the installation instructions or update the existing environment with `pip install streamlit-keyup`.
+  - Download the [weights](https://huggingface.co/stabilityai/sdxl-turbo) and place them in the `checkpoints/` directory.
+  - Run `streamlit run scripts/demo/turbo.py`.
+
+![tile](assets/turbo_tile.png)
+
+
+**November 21, 2023**
+- We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
+  - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
+    frames at resolution 576x1024 given a context frame of the same size.
+    We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
+  - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
+    for 25 frame generation.
+  - You can run the community-built gradio demo locally by running `python -m scripts.demo.gradio_app`.
+  - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+  - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
+
+![tile](assets/tile.gif)
+
+**July 26, 2023**
+
+- We are releasing two new open models with a
+  permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
+  hashes):
+  - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
+    over `SDXL-base-0.9`.
+  - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
+    over `SDXL-refiner-0.9`.
+
+![001_with_eval](assets/001_with_eval.png)
+
+**July 4, 2023**
+
+- A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
+
 **June 22, 2023**
 
+
+- We are releasing two new diffusion models for research purposes:
+  - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
+    base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
+    and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
+    the OpenCLIP model.
+  - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
+    not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
+
-- We are releasing two new diffusion models:
+If you would like to access these models for your research, please apply using one of the following links:
-  - `SD-XL 0.9-base`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses the OpenCLIP model.
+[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
-  - `SD-XL 0.9-refiner`: The refiner has been trained to denoise small noise levels of high quality data and as such is not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
+and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+This means that you can apply for any of the two links - and if you are granted - you can access both.
+Please log in to your Hugging Face Account with your organization email to request access.
 **We plan to do a full release soon (July).**
 
 ## The codebase
 
 ### General Philosophy
 
-Modularity is king. This repo implements a config-driven approach where we build and combine submodules by calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
+Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
+calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
 
 ### Changelog from the old `ldm` codebase
 
-For training, we use [pytorch-lightning](https://www.pytorchlightning.ai/index.html), but it should be easy to use other training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`, now `DiffusionEngine`) has been cleaned up:
+For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
+training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
+now `DiffusionEngine`) has been cleaned up:
 
-- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial conditionings, and all combinations thereof) in a single class: `GeneralConditioner`, see `sgm/modules/encoders/modules.py`.
+- No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
+  conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
+  see `sgm/modules/encoders/modules.py`.
 - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
   samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
-- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable change is probably now the option to train continuous time models):
+- We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (most notable
-  * Discrete times models (denoisers) are simply a special case of continuous time models (denoisers); see `sgm/modules/diffusionmodules/denoiser.py`.
+  change is probably now the option to train continuous time models):
-  * The following features are now independent: weighting of the diffusion loss function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
+  * Discrete times models (denoisers) are simply a special case of continuous time models (denoisers);
+    see `sgm/modules/diffusionmodules/denoiser.py`.
+  * The following features are now independent: weighting of the diffusion loss
+    function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
+    network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
+    training (`sgm/modules/diffusionmodules/sigma_sampling.py`).
 - Autoencoding models have also been cleaned up.
 
 ## Installation:
 
 <a name="installation"></a>
 
 #### 1. Clone the repo
 
 ```shell
-git clone git@github.com:Stability-AI/generative-models.git
+git clone https://github.com/Stability-AI/generative-models.git
 cd generative-models
 ```
 
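Put together, the SV4D 2.0 quickstart added in this hunk amounts to the following sketch (the environment setup follows the "Install dependencies before running" block above and assumes CUDA 11.8):

```bash
# one-time environment setup (from the dependency block above)
python3.10 -m venv .generativemodels
source .generativemodels/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118  # check CUDA version
pip3 install -r requirements/pt2.txt
pip3 install .
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata

# fetch the SV4D 2.0 weights and run inference on the bundled example video
huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints
python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs
```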
@@ -45,43 +193,84 @@ cd generative-models
 
 This is assuming you have navigated to the `generative-models` root after cloning it.
 
-**NOTE:** This is tested under `python3.8` and `python3.10`. For other python versions, you might encounter version conflicts.
+**NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
 
 
-**PyTorch 1.13**
-
-```shell
-# install required packages from pypi
-python3 -m venv .pt1
-source .pt1/bin/activate
-pip3 install wheel
-pip3 install -r requirements_pt13.txt
-```
-
 **PyTorch 2.0**
 
 
 ```shell
 # install required packages from pypi
 python3 -m venv .pt2
 source .pt2/bin/activate
-pip3 install wheel
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-pip3 install -r requirements_pt2.txt
+pip3 install -r requirements/pt2.txt
 ```
 
-## Inference:
+#### 3. Install `sgm`
 
-We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling in `scripts/demo/sampling.py`. The following models are currently supported:
+```shell
-- [SD-XL 0.9-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
+pip3 install .
-- [SD-XL 0.9-refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
+```
-- [SD 2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
+
-- [SD 2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
+#### 4. Install `sdata` for training
 
+```shell
+pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+```
+
+## Packaging
+
+This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).
+
+To build a distributable wheel, install `hatch` and run `hatch build`
+(specifying `-t wheel` will skip building a sdist, which is not necessary).
+
+```
+pip install hatch
+hatch build -t wheel
+```
+
+You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
+
+Note that the package does **not** currently specify dependencies; you will need to install the required packages,
+depending on your use case and PyTorch version, manually.
+
+## Inference
+
+We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
+in `scripts/demo/sampling.py`.
+We provide file hashes for the complete file as well as for only the saved tensors in the file (
+see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to evaluate that).
+The following models are currently supported:
+
+- [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+```
+File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
+Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
+```
+- [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
+```
+File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
+Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
+```
+- [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
+- [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
 
 **Weights for SDXL**:
 
+**SDXL-1.0:**
+The weights of SDXL-1.0 are available (subject to
+a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
+
+- base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
+- refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
+
+**SDXL-0.9:**
+The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
 If you would like to access these models for your research, please apply using one of the following links:
-[SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9), and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
+[SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
+and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
 This means that you can apply for any of the two links - and if you are granted - you can access both.
-Please log in to your HuggingFace Account with your organization email to request access.
+Please log in to your Hugging Face Account with your organization email to request access.
 
 After obtaining the weights, place them into `checkpoints/`.
 Next, start the demo using
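The installation and packaging steps added in this hunk chain together roughly as follows (a sketch of the PyTorch 2 path; as noted above, the built wheel does not declare dependencies):

```bash
git clone https://github.com/Stability-AI/generative-models.git
cd generative-models
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install -r requirements/pt2.txt
pip3 install .                                                                          # step 3: install sgm
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata    # step 4: sdata, for training

# optional: build and install a distributable wheel with Hatch
pip install hatch
hatch build -t wheel
pip install dist/*.whl
```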
@@ -100,6 +289,7 @@ not the same as in previous Stable Diffusion 1.x/2.x versions.
 
 To run the script you need to either have a working installation as above or
 try an _experimental_ import using only a minimal amount of packages:
+
 ```bash
 python -m venv .detect
 source .detect/bin/activate
@@ -111,6 +301,7 @@ pip install --no-deps invisible-watermark
 To run the script you need to have a working installation as above. The script
 is then useable in the following ways (don't forget to activate your
 virtual environment beforehand, e.g. `source .pt1/bin/activate`):
+
 ```bash
 # test a single file
 python scripts/demo/detect.py <your filename here>
@@ -137,11 +328,21 @@ run
 python main.py --base configs/example_training/toy/mnist_cond.yaml
 ```
 
-**NOTE 1:** Using the non-toy-dataset configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml` and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depdending on the used dataset (which is expected to stored in tar-file in the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search for comments containing `USER:` in the respective config.
+**NOTE 1:** Using the non-toy-dataset
+configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
+and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
+used dataset (which is expected to be stored in tar-file in
+the [webdataset-format](https://github.com/webdataset/webdataset)). To find the parts which have to be adapted, search
+for comments containing `USER:` in the respective config.
 
-**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2`for training generative models. However for autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`, only `pytorch1.13` is supported.
+**NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for
+autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
+only `pytorch1.13` is supported.
 
-**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done for the provided text-to-image configs.
+**NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
+retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
+the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
+for the provided text-to-image configs.
 
 ### Building New Diffusion Models
 
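Taking NOTE 1-3 together, a typical non-toy training run looks roughly like this sketch (the exact checkpoint file name under `stabilityai/sdxl-vae` and your dataset location are assumptions you must fill in yourself):

```bash
# NOTE 3: fetch the first-stage (VAE) checkpoint and point CKPT_PATH at it
# (the concrete file layout under stabilityai/sdxl-vae is an assumption here)
huggingface-cli download stabilityai/sdxl-vae --local-dir checkpoints/sdxl-vae
# then edit configs/example_training/imagenet-f8_cond.yaml:
#   - replace the CKPT_PATH placeholder with the downloaded checkpoint path
#   - NOTE 1: point the webdataset urls at your own tar shards (search for "USER:" comments)
python main.py --base configs/example_training/imagenet-f8_cond.yaml
```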
@@ -150,7 +351,8 @@ python main.py --base configs/example_training/toy/mnist_cond.yaml
 The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
 different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
 All embedders should define whether or not they are trainable (`is_trainable`, default `False`), a classifier-free
-guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for text-conditioning or `cls` for class-conditioning.
+guidance dropout rate is used (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
+text-conditioning or `cls` for class-conditioning.
 When computing conditionings, the embedder will get `batch[input_key]` as input.
 We currently support two to four dimensional conditionings and conditionings of different embedders are concatenated
 appropriately.
@@ -163,7 +365,8 @@ enough as we plan to experiment with transformer-based diffusion backbones.
 
 #### Loss
 
-The loss is configured through `loss_config`. For standard diffusion model training, you will have to set `sigma_sampler_config`.
+The loss is configured through `loss_config`. For standard diffusion model training, you will have to
+set `sigma_sampler_config`.
 
 #### Sampler config
 
@@ -173,8 +376,9 @@ guidance.
 
 ### Dataset Handling
 
-For large scale training we recommend using the datapipelines from our [datapipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement and automatically included when following the steps from the [Installation section](#installation).
+For large scale training we recommend using the data pipelines from
+our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is contained in the requirement
+and automatically included when following the steps from the [Installation section](#installation).
 Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
 data keys/values,
 e.g.,
BIN  assets/001_with_eval.png              (new file, 4.0 MiB)
BIN  assets/sv3d.gif                       (new file, 1.2 MiB)
BIN  assets/sv4d.gif                       (new file, 8.0 MiB)
BIN  assets/sv4d2.gif                      (new file, 9.7 MiB)
BIN  assets/sv4d_videos/bear.gif           (new file, 2.2 MiB)
BIN  assets/sv4d_videos/bee.gif            (new file, 638 KiB)
BIN  assets/sv4d_videos/bmx-bumps.gif      (new file, 2.2 MiB)
BIN  assets/sv4d_videos/camel.gif          (new file, 1.9 MiB)
BIN  assets/sv4d_videos/chameleon.gif      (new file, 1.4 MiB)
BIN  assets/sv4d_videos/chest.gif          (new file, 2.2 MiB)
BIN  assets/sv4d_videos/cows.gif           (new file, 1.7 MiB)
BIN  assets/sv4d_videos/dance-twirl.gif    (new file, 1.2 MiB)
BIN  assets/sv4d_videos/flag.gif           (new file, 2.1 MiB)
BIN  assets/sv4d_videos/gear.gif           (new file, 446 KiB)
BIN  assets/sv4d_videos/hike.gif           (new file, 1.6 MiB)
BIN  assets/sv4d_videos/horsejump-low.gif  (new file, 1.4 MiB)
BIN  assets/sv4d_videos/robot.gif          (new file, 946 KiB)
BIN  assets/sv4d_videos/snowboard.gif      (new file, 1.5 MiB)
BIN  assets/sv4d_videos/test_video1.mp4    (new file)
BIN  assets/sv4d_videos/windmill.gif       (new file, 2.4 MiB)
BIN  assets/test_image.png                 (new file, 482 KiB)
BIN  assets/tile.gif                       (new file, 18 MiB)
BIN  assets/turbo_tile.png                 (new file, 2.1 MiB)
@@ -29,25 +29,14 @@ model:
         in_channels: 3
         out_ch: 3
         ch: 128
-        ch_mult: [ 1, 2, 4 ]
+        ch_mult: [1, 2, 4]
         num_res_blocks: 4
-        attn_resolutions: [ ]
+        attn_resolutions: []
         dropout: 0.0
 
     decoder_config:
       target: sgm.modules.diffusionmodules.model.Decoder
-      params:
+      params: ${model.params.encoder_config.params}
-        attn_type: none
-        double_z: False
-        z_channels: 4
-        resolution: 256
-        in_channels: 3
-        out_ch: 3
-        ch: 128
-        ch_mult: [ 1, 2, 4 ]
-        num_res_blocks: 4
-        attn_resolutions: [ ]
-        dropout: 0.0
 
 data:
   target: sgm.data.dataset.StableDataModuleFromConfig
@@ -55,18 +44,18 @@ data:
     train:
       datapipeline:
        urls:
-          - "DATA-PATH"
+          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000
 
        decoders:
-          - "pil"
+          - pil
 
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
-              key: 'jpg'
+              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
@@ -0,0 +1,105 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: sgm.models.autoencoder.AutoencodingEngine
+  params:
+    input_key: jpg
+    monitor: val/loss/rec
+    disc_start_iter: 0
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: vanilla-xformers
+        double_z: true
+        z_channels: 8
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4, 4]
+        num_res_blocks: 2
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 20001
+        disc_weight: 0.5
+        learn_logvar: True
+
+        regularization_weights:
+          kl_loss: 1.0
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          - DATA-PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height
+              w_key: width
+
+      loader:
+        batch_size: 8
+        num_workers: 4
+
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+
+  trainer:
+    devices: 0,
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    val_check_interval: 10000
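The new autoencoder training config above would be launched like any other config in this repo, e.g. as in this sketch (the file path is hypothetical, since the diff header does not show where the config lives):

```bash
# path is illustrative only -- substitute the actual location of the new config
python main.py --base configs/example_training/autoencoder/<new-autoencoder-config>.yaml
```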
@@ -21,8 +21,6 @@ model:
      params:
        num_idx: 1000
 
-        weighting_config:
-          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
@@ -32,7 +30,6 @@ model:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
-        use_fp16: True
        in_channels: 4
        out_channels: 4
        model_channels: 256
@@ -42,7 +39,6 @@ model:
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1024
-        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
@@ -51,32 +47,31 @@ model:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
-          # crossattn cond
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
-              add_sequence_dim: True # will be used through crossattn then
+              add_sequence_dim: True
              embed_dim: 1024
              n_classes: 1000
-          # vector cond
          - is_trainable: False
            ucg_rate: 0.2
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
-              outdim: 256 # multiplied by two
+              outdim: 256
-          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
-              outdim: 256 # multiplied by two
+              outdim: 256
 
    first_stage_config:
-      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+      target: sgm.models.autoencoder.AutoencoderKL
      params:
        ckpt_path: CKPT_PATH
        embed_dim: 4
@@ -99,6 +94,8 @@ model:
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
@@ -127,18 +124,18 @@ data:
      datapipeline:
        urls:
          # USER: adapt this path the root of your custom dataset
-          - "DATA_PATH"
+          - DATA_PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
 
        decoders:
-          - "pil"
+          - pil
 
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
-              key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+              key: jpg # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
@@ -5,10 +5,6 @@ model:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
-        weighting_config:
-          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-          params:
-            sigma_data: 1.0
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
@@ -17,7 +13,6 @@ model:
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
        in_channels: 3
        out_channels: 3
        model_channels: 32
@@ -46,6 +41,10 @@ model:
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
 
@@ -5,10 +5,6 @@ model:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
-        weighting_config:
-          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-          params:
-            sigma_data: 1.0
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
@@ -17,7 +13,6 @@ model:
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
@@ -32,6 +27,10 @@ model:
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
 
@@ -5,10 +5,6 @@ model:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
-        weighting_config:
-          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-          params:
-            sigma_data: 1.0
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
@@ -17,13 +13,12 @@ model:
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
-        attention_resolutions: [ ]
+        attention_resolutions: []
        num_res_blocks: 4
-        channel_mult: [ 1, 2, 2 ]
+        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128
@@ -33,7 +28,7 @@ model:
      params:
        emb_models:
          - is_trainable: True
-            input_key: "cls"
+            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
@@ -46,6 +41,10 @@ model:
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+          params:
+            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
 
@@ -7,8 +7,6 @@ model:
      params:
        num_idx: 1000
 
-        weighting_config:
-          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
        discretization_config:
@@ -17,13 +15,12 @@ model:
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
-        attention_resolutions: [ ]
+        attention_resolutions: []
        num_res_blocks: 4
-        channel_mult: [ 1, 2, 2 ]
+        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128
@@ -33,7 +30,7 @@ model:
      params:
        emb_models:
          - is_trainable: True
-            input_key: "cls"
+            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
@@ -46,6 +43,8 @@ model:
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
@@ -5,10 +5,6 @@ model:
 denoiser_config:
 target: sgm.modules.diffusionmodules.denoiser.Denoiser
 params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
 params:
@@ -17,7 +13,6 @@ model:
 network_config:
 target: sgm.modules.diffusionmodules.openaimodel.UNetModel
 params:
-use_checkpoint: True
 in_channels: 1
 out_channels: 1
 model_channels: 32
@@ -25,7 +20,7 @@ model:
 num_res_blocks: 4
 channel_mult: [1, 2, 2]
 num_head_channels: 32
-num_classes: "sequential"
+num_classes: sequential
 adm_in_channels: 128

 conditioner_config:
@@ -33,7 +28,7 @@ model:
 params:
 emb_models:
 - is_trainable: True
-input_key: "cls"
+input_key: cls
 ucg_rate: 0.2
 target: sgm.modules.encoders.modules.ClassEmbedder
 params:
@@ -46,6 +41,11 @@ model:
 loss_fn_config:
 target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
 params:
+loss_type: l1
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
 sigma_sampler_config:
 target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

@@ -62,11 +62,6 @@ model:
 params:
 scale: 3.0

-loss_config:
-target: sgm.modules.diffusionmodules.StandardDiffusionLoss
-params:
-type: l1
-
 data:
 target: sgm.data.mnist.MNISTLoader
 params:
@@ -7,10 +7,6 @@ model:
 denoiser_config:
 target: sgm.modules.diffusionmodules.denoiser.Denoiser
 params:
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
-params:
-sigma_data: 1.0
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
 params:
@@ -19,7 +15,6 @@ model:
 network_config:
 target: sgm.modules.diffusionmodules.openaimodel.UNetModel
 params:
-use_checkpoint: True
 in_channels: 1
 out_channels: 1
 model_channels: 32
@@ -48,6 +43,10 @@ model:
 loss_fn_config:
 target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
 params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+params:
+sigma_data: 1.0
 sigma_sampler_config:
 target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

@@ -10,19 +10,17 @@ model:
 scheduler_config:
 target: sgm.lr_scheduler.LambdaLinearScheduler
 params:
-warm_up_steps: [ 10000 ]
+warm_up_steps: [10000]
-cycle_lengths: [ 10000000000000 ]
+cycle_lengths: [10000000000000]
-f_start: [ 1.e-6 ]
+f_start: [1.e-6]
-f_max: [ 1. ]
+f_max: [1.]
-f_min: [ 1. ]
+f_min: [1.]

 denoiser_config:
 target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
 params:
 num_idx: 1000

-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
 discretization_config:
@@ -32,18 +30,16 @@ model:
 target: sgm.modules.diffusionmodules.openaimodel.UNetModel
 params:
 use_checkpoint: True
-use_fp16: True
 in_channels: 4
 out_channels: 4
 model_channels: 320
-attention_resolutions: [ 1, 2, 4 ]
+attention_resolutions: [1, 2, 4]
 num_res_blocks: 2
-channel_mult: [ 1, 2, 4, 4 ]
+channel_mult: [1, 2, 4, 4]
 num_head_channels: 64
 num_classes: sequential
 adm_in_channels: 1792
 num_heads: 1
-use_spatial_transformer: true
 transformer_depth: 1
 context_dim: 768
 spatial_transformer_attn_type: softmax-xformers
@@ -52,7 +48,6 @@ model:
 target: sgm.modules.GeneralConditioner
 params:
 emb_models:
-# crossattn cond
 - is_trainable: True
 input_key: txt
 ucg_rate: 0.1
@@ -60,23 +55,23 @@ model:
 target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
 params:
 always_return_pooled: True
-# vector cond
 - is_trainable: False
 ucg_rate: 0.1
 input_key: original_size_as_tuple
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: crop_coords_top_left
 ucg_rate: 0.1
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256

 first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
 params:
 ckpt_path: CKPT_PATH
 embed_dim: 4
@@ -99,6 +94,8 @@ model:
 loss_fn_config:
 target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
 params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
 sigma_sampler_config:
 target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
 params:
@@ -127,18 +124,18 @@ data:
 datapipeline:
 urls:
 # USER: adapt this path the root of your custom dataset
-- "DATA_PATH"
+- DATA_PATH
 pipeline_config:
 shardshuffle: 10000
 sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM

 decoders:
-- "pil"
+- pil

 postprocessors:
 - target: sdata.mappers.TorchVisionImageTransforms
 params:
-key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+key: jpg # USER: you might wanna adapt this for your custom dataset
 transforms:
 - target: torchvision.transforms.Resize
 params:
@@ -10,19 +10,17 @@ model:
 scheduler_config:
 target: sgm.lr_scheduler.LambdaLinearScheduler
 params:
-warm_up_steps: [ 10000 ]
+warm_up_steps: [10000]
-cycle_lengths: [ 10000000000000 ]
+cycle_lengths: [10000000000000]
-f_start: [ 1.e-6 ]
+f_start: [1.e-6]
-f_max: [ 1. ]
+f_max: [1.]
-f_min: [ 1. ]
+f_min: [1.]

 denoiser_config:
 target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
 params:
 num_idx: 1000

-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
 discretization_config:
@@ -32,18 +30,16 @@ model:
 target: sgm.modules.diffusionmodules.openaimodel.UNetModel
 params:
 use_checkpoint: True
-use_fp16: True
 in_channels: 4
 out_channels: 4
 model_channels: 320
-attention_resolutions: [ 1, 2, 4 ]
+attention_resolutions: [1, 2, 4]
 num_res_blocks: 2
-channel_mult: [ 1, 2, 4, 4 ]
+channel_mult: [1, 2, 4, 4]
 num_head_channels: 64
 num_classes: sequential
 adm_in_channels: 1792
 num_heads: 1
-use_spatial_transformer: true
 transformer_depth: 1
 context_dim: 768
 spatial_transformer_attn_type: softmax-xformers
@@ -52,30 +48,30 @@ model:
 target: sgm.modules.GeneralConditioner
 params:
 emb_models:
-# crossattn cond
 - is_trainable: True
 input_key: txt
 ucg_rate: 0.1
+legacy_ucg_value: ""
 target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
 params:
 always_return_pooled: True
-# vector cond
 - is_trainable: False
 ucg_rate: 0.1
 input_key: original_size_as_tuple
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: crop_coords_top_left
 ucg_rate: 0.1
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256

 first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
 params:
 ckpt_path: CKPT_PATH
 embed_dim: 4
@@ -88,9 +84,9 @@ model:
 in_channels: 3
 out_ch: 3
 ch: 128
-ch_mult: [ 1, 2, 4, 4 ]
+ch_mult: [1, 2, 4, 4]
 num_res_blocks: 2
-attn_resolutions: [ ]
+attn_resolutions: []
 dropout: 0.0
 lossconfig:
 target: torch.nn.Identity
@@ -98,6 +94,8 @@ model:
 loss_fn_config:
 target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
 params:
+loss_weighting_config:
+target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
 sigma_sampler_config:
 target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
 params:
@@ -126,19 +124,19 @@ data:
 datapipeline:
 urls:
 # USER: adapt this path the root of your custom dataset
-- "DATA_PATH"
+- DATA_PATH
 pipeline_config:
 shardshuffle: 10000
 sample_shuffle: 10000


 decoders:
-- "pil"
+- pil

 postprocessors:
 - target: sdata.mappers.TorchVisionImageTransforms
 params:
-key: 'jpg' # USER: you might wanna adapt this for your custom dataset
+key: jpg # USER: you might wanna adapt this for your custom dataset
 transforms:
 - target: torchvision.transforms.Resize
 params:
@@ -1,66 +0,0 @@
-model:
-target: sgm.models.diffusion.DiffusionEngine
-params:
-scale_factor: 0.18215
-disable_first_stage_autocast: True
-
-denoiser_config:
-target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
-params:
-num_idx: 1000
-
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
-scaling_config:
-target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
-discretization_config:
-target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
-
-network_config:
-target: sgm.modules.diffusionmodules.openaimodel.UNetModel
-params:
-use_checkpoint: True
-use_fp16: True
-in_channels: 4
-out_channels: 4
-model_channels: 320
-attention_resolutions: [4, 2, 1]
-num_res_blocks: 2
-channel_mult: [1, 2, 4, 4]
-num_head_channels: 64
-use_spatial_transformer: True
-use_linear_in_transformer: True
-transformer_depth: 1
-context_dim: 1024
-legacy: False
-
-conditioner_config:
-target: sgm.modules.GeneralConditioner
-params:
-emb_models:
-# crossattn cond
-- is_trainable: False
-input_key: txt
-target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-params:
-freeze: true
-layer: penultimate
-
-first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
-params:
-embed_dim: 4
-monitor: val/rec_loss
-ddconfig:
-double_z: true
-z_channels: 4
-resolution: 256
-in_channels: 3
-out_ch: 3
-ch: 128
-ch_mult: [1, 2, 4, 4]
-num_res_blocks: 2
-attn_resolutions: []
-dropout: 0.0
-lossconfig:
-target: torch.nn.Identity
@@ -1,66 +0,0 @@
-model:
-target: sgm.models.diffusion.DiffusionEngine
-params:
-scale_factor: 0.18215
-disable_first_stage_autocast: True
-
-denoiser_config:
-target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
-params:
-num_idx: 1000
-
-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
-scaling_config:
-target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
-discretization_config:
-target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
-
-network_config:
-target: sgm.modules.diffusionmodules.openaimodel.UNetModel
-params:
-use_checkpoint: True
-use_fp16: True
-in_channels: 4
-out_channels: 4
-model_channels: 320
-attention_resolutions: [4, 2, 1]
-num_res_blocks: 2
-channel_mult: [1, 2, 4, 4]
-num_head_channels: 64
-use_spatial_transformer: True
-use_linear_in_transformer: True
-transformer_depth: 1
-context_dim: 1024
-legacy: False
-
-conditioner_config:
-target: sgm.modules.GeneralConditioner
-params:
-emb_models:
-# crossattn cond
-- is_trainable: False
-input_key: txt
-target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-params:
-freeze: true
-layer: penultimate
-
-first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
-params:
-embed_dim: 4
-monitor: val/rec_loss
-ddconfig:
-double_z: true
-z_channels: 4
-resolution: 256
-in_channels: 3
-out_ch: 3
-ch: 128
-ch_mult: [1, 2, 4, 4]
-num_res_blocks: 2
-attn_resolutions: []
-dropout: 0.0
-lossconfig:
-target: torch.nn.Identity
@@ -9,8 +9,6 @@ model:
 params:
 num_idx: 1000

-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
 discretization_config:
@@ -29,25 +27,22 @@ model:
 num_res_blocks: 2
 channel_mult: [1, 2, 4]
 num_head_channels: 64
-use_spatial_transformer: True
 use_linear_in_transformer: True
-transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+transformer_depth: [1, 2, 10]
 context_dim: 2048
 spatial_transformer_attn_type: softmax-xformers
-legacy: False

 conditioner_config:
 target: sgm.modules.GeneralConditioner
 params:
 emb_models:
-# crossattn cond
 - is_trainable: False
 input_key: txt
 target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
 params:
 layer: hidden
 layer_idx: 11
-# crossattn and vector cond
 - is_trainable: False
 input_key: txt
 target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -58,27 +53,27 @@ model:
 layer: penultimate
 always_return_pooled: True
 legacy: False
-# vector cond
 - is_trainable: False
 input_key: original_size_as_tuple
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: crop_coords_top_left
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: target_size_as_tuple
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256

 first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
 params:
 embed_dim: 4
 monitor: val/rec_loss
@@ -9,8 +9,6 @@ model:
 params:
 num_idx: 1000

-weighting_config:
-target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
 scaling_config:
 target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
 discretization_config:
@@ -29,18 +27,15 @@ model:
 num_res_blocks: 2
 channel_mult: [1, 2, 4, 4]
 num_head_channels: 64
-use_spatial_transformer: True
 use_linear_in_transformer: True
 transformer_depth: 4
-context_dim: [1280, 1280, 1280, 1280] # 1280
+context_dim: [1280, 1280, 1280, 1280]
 spatial_transformer_attn_type: softmax-xformers
-legacy: False

 conditioner_config:
 target: sgm.modules.GeneralConditioner
 params:
 emb_models:
-# crossattn and vector cond
 - is_trainable: False
 input_key: txt
 target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
@@ -51,27 +46,27 @@ model:
 freeze: True
 layer: penultimate
 always_return_pooled: True
-# vector cond
 - is_trainable: False
 input_key: original_size_as_tuple
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: crop_coords_top_left
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by two
+outdim: 256
-# vector cond
 - is_trainable: False
 input_key: aesthetic_score
 target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
 params:
-outdim: 256 # multiplied by one
+outdim: 256

 first_stage_config:
-target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+target: sgm.models.autoencoder.AutoencoderKL
 params:
 embed_dim: 4
 monitor: val/rec_loss
118
configs/inference/sv3d_p.yaml
Normal file
@@ -0,0 +1,118 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 1280
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- input_key: cond_frames_without_noise
is_trainable: False
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: polars_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

- input_key: azimuths_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
106
configs/inference/sv3d_u.yaml
Normal file
@@ -0,0 +1,106 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 256
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- input_key: cond_frames_without_noise
is_trainable: False
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
131
configs/inference/svd.yaml
Normal file
@@ -0,0 +1,131 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: sgm.modules.diffusionmodules.model.Encoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
decoder_config:
target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
params:
attn_type: vanilla
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
video_kernel_size: [3, 1, 1]
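As a rough usage sketch (not part of this diff): an inference config such as configs/inference/svd.yaml above is typically loaded with OmegaConf and built via instantiate_from_config, both of which appear in main.py later in this comparison. The checkpoint path below is a hypothetical placeholder and strict=False is an illustrative choice, not a statement about the official loader.

# Minimal sketch: build the SVD DiffusionEngine from the config listed above.
import torch
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/inference/svd.yaml")
model = instantiate_from_config(config.model)

ckpt = torch.load("checkpoints/svd.ckpt", map_location="cpu")  # hypothetical path
model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)
model.eval()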
114
configs/inference/svd_image_decoder.yaml
Normal file
@@ -0,0 +1,114 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True

denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 768
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]

conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True

- input_key: fps_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: motion_bucket_id
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256

first_stage_config:
target: sgm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
14
main.py
@@ -12,22 +12,18 @@ import pytorch_lightning as pl
 import torch
 import torchvision
 import wandb
-from PIL import Image
 from matplotlib import pyplot as plt
 from natsort import natsorted
 from omegaconf import OmegaConf
 from packaging import version
+from PIL import Image
 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.utilities import rank_zero_only

-from sgm.util import (
-    exists,
-    instantiate_from_config,
-    isheatmap,
-)
+from sgm.util import exists, instantiate_from_config, isheatmap

 MULTINODE_HACKS = True

@@ -469,9 +465,8 @@ class ImageLogger(Callback):
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
-    # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    def on_validation_batch_end(
-        self, trainer, pl_module, outputs, batch, batch_idx, **kwargs
+        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
@@ -911,11 +906,12 @@ if __name__ == "__main__":
            trainer.test(model, data)
        except RuntimeError as err:
            if MULTINODE_HACKS:
-                import requests
                import datetime
                import os
                import socket

+                import requests
+
                device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
                hostname = socket.gethostname()
                ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
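A hedged aside on the on_validation_batch_end change above: some pytorch_lightning versions pass an extra positional dataloader_idx argument to this hook, so accepting *args keeps the callback signature compatible across versions. A minimal sketch of the same pattern, for illustration only and not taken from the project's code:

# Illustration only: a callback whose hook tolerates the extra positional
# dataloader_idx argument that some pytorch_lightning versions supply.
from pytorch_lightning.callbacks import Callback


class ValBatchLogger(Callback):  # hypothetical name, not from the diff
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        # *args absorbs dataloader_idx (and anything else) when it is passed.
        print(f"validation batch {batch_idx} finished")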
58
model_licenses/LICENSE-SDXL-Turbo
Normal file
@@ -0,0 +1,58 @@
STABILITY AI NON-COMMERCIAL RESEARCH COMMUNITY LICENSE AGREEMENT
Dated: November 28, 2023

By using or distributing any portion or element of the Models, Software, Software Products or Derivative Works, you agree to be bound by this Agreement.

"Agreement" means this Stable Non-Commercial Research Community License Agreement.

“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.

“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.

“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.

"Stability AI" or "we" means Stability AI Ltd. and its affiliates.

"Software" means Stability AI’s proprietary software made available under this Agreement.

“Software Products” means the Models, Software and Documentation, individually or in any combination.

1. License Rights and Redistribution.

a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to reproduce the Software Products and produce, reproduce, distribute, and create Derivative Works of the Software Products for Non-Commercial Uses only, respectively.

b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.

c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.

2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.

3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

4. Intellectual Property.

a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.

b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works

c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.

5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.
175
model_licenses/LICENSE-SDXL1.0
Normal file
@@ -0,0 +1,175 @@
Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023

Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI. This CreativeML Open RAIL++-M License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. NOW THEREFORE, You and Licensor agree as follows: Definitions "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. "Output" means the results of operating a Model as embodied in informational content resulting therefrom. "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.

Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in

Section III. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; You must cause any modified files to carry prominent notices stating that You changed the files; You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.

Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS

Attachment A Use Restrictions

You agree not to use the Model or Derivatives of the Model:
- In any way that violates any applicable national, federal, state, local or international law or regulation;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate personal identifiable information that can be used to harm an individual;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
- To provide medical advice and medical results interpretation;
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
41
model_licenses/LICENSE-SV3D
Normal file
@@ -0,0 +1,41 @@
STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT

Dated: March 18, 2024

"Agreement" means this Stable Non-Commercial Research Community License Agreement.

“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.

“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.

“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works.

"Stability AI" or "we" means Stability AI Ltd and its affiliates.

"Software" means Stability AI’s proprietary software made available under this Agreement.

“Software Products” means the Models, Software and Documentation, individually or in any combination.

1. License Rights and Redistribution.

a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only.

b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact.

c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.

2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.

3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

4. Intellectual Property.

a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works.

b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.

c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement.

5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement.

6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law principles.
31
model_licenses/LICENSE-SVD
Normal file
@@ -0,0 +1,31 @@
STABLE VIDEO DIFFUSION NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT

Dated: November 21, 2023

“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.

"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.

"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model.

“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

"Stability AI" or "we" means Stability AI Ltd.

"Software" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement.

“Software Products” means Software and Documentation.

By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.

License Rights and Redistribution.

Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create Derivative Works of the Software Products for purposes other than commercial or production use.

b. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Stable Video Diffusion is licensed under the Stable Video Diffusion Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified.

2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.

3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.

3. Intellectual Property.

a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.

Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works.

If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.

4. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
48
pyproject.toml
Normal file
@@ -0,0 +1,48 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "sgm"
dynamic = ["version"]
description = "Stability Generative Models"
readme = "README.md"
license-files = { paths = ["LICENSE-CODE"] }
requires-python = ">=3.8"

[project.urls]
Homepage = "https://github.com/Stability-AI/generative-models"

[tool.hatch.version]
path = "sgm/__init__.py"

[tool.hatch.build]
# This needs to be explicitly set so the configuration files
# grafted into the `sgm` directory get included in the wheel's
# RECORD file.
include = [
    "sgm",
]
# The force-include configurations below make Hatch copy
# the configs/ directory (containing the various YAML files required
# to generatively model) into the source distribution and the wheel.

[tool.hatch.build.targets.sdist.force-include]
"./configs" = "sgm/configs"

[tool.hatch.build.targets.wheel.force-include]
"./configs" = "sgm/configs"

[tool.hatch.envs.ci]
skip-install = false

dependencies = [
    "pytest"
]

[tool.hatch.envs.ci.scripts]
test-inference = [
    "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
    "pip install -r requirements/pt2.txt",
    "pytest -v tests/inference/test_inference.py {args}",
]
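As a quick illustration of what the force-include mapping buys you, the sketch below lists the YAML files that end up inside the installed package. This is an editor's example, not part of the diff, and assumes a wheel was built with the configuration above.

import os

import sgm

# The force-include rules graft ./configs into sgm/configs, so an installed wheel
# carries the YAML config files alongside the package code.
config_root = os.path.join(os.path.dirname(sgm.__file__), "configs")
for root, _dirs, files in os.walk(config_root):
    for name in sorted(files):
        if name.endswith(".yaml"):
            print(os.path.join(root, name))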
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
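For reference, a test that uses this marker might look like the following minimal sketch; the test name and assertion are placeholders rather than code from this diff.

import pytest


@pytest.mark.inference
def test_dummy_inference():
    # Selected by default; skipped with: pytest -m "not inference"
    assert True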
45
requirements/pt2.txt
Normal file
@@ -0,0 +1,45 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
imageio[ffmpeg]
imageio[pyav]
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy==2.1
omegaconf>=2.3.0
onnxruntime
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
rembg
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20
gradio
streamlit-keyup==0.2.0
@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
xformers==0.0.16
torchaudio==0.13.1
torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
matplotlib
natsort
tensorboardx==2.5.1
open-clip-torch
chardet
scipy
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
streamlit>=0.73.1
timm
tokenizers==0.12.1
torchdata==0.5.1
transformers==4.19.1
onnx<=1.12.0
triton
wandb
invisible-watermark
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .
@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
torch
matplotlib
torchaudio>=2.0.2
torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fire
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.0.0
torchdata==0.6.1
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .
0
scripts/__init__.py
Normal file
0
scripts/demo/__init__.py
Normal file
@@ -83,7 +83,7 @@ class GetWatermarkMatch
     def __call__(self, x: np.ndarray) -> np.ndarray:
         """
         Detects the number of matching bits the predefined watermark with one
-        or multiple images. Images should be in cv2 format, e.g. h x w x c.
+        or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.

         Args:
             x: ([B], h w, c) in range [0, 255]
@@ -94,7 +94,6 @@ class GetWatermarkMatch
         squeeze = len(x.shape) == 3
         if squeeze:
             x = x[None, ...]
-        x = np.flip(x, axis=-1)

         bs = x.shape[0]
         detected = np.empty((bs, self.num_bits), dtype=bool)
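With the channel flip removed, callers are expected to hand the detector BGR frames directly. A usage sketch follows; the import path and the WATERMARK_BITS name are assumptions about the surrounding demo script, not something this hunk shows.

import cv2
import numpy as np

from scripts.demo.detect import GetWatermarkMatch, WATERMARK_BITS  # hypothetical import path

matcher = GetWatermarkMatch(WATERMARK_BITS)

# cv2.imread already yields h x w x c uint8 in BGR order, matching the updated
# docstring, so no channel flip is needed before calling the matcher.
frame = cv2.imread("sample.png")
matches = matcher(frame)  # boolean array with one entry per watermark bit
print("matching bits:", int(np.sum(matches)))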
59
scripts/demo/discretization.py
Normal file
@@ -0,0 +1,59 @@
import torch

from sgm.modules.diffusionmodules.discretizer import Discretization


class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.

    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization: Discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img:", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        prune_index = max(int(self.strength * len(sigmas)), 1)
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning:", sigmas)
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.

    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
    """

    def __init__(
        self, discretization: Discretization, strength: float = 0.0, original_steps=None
    ):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning txt2noisy:", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning:", sigmas)
        return sigmas
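To illustrate how the wrapper above is meant to be used, here is a minimal sketch. It assumes LegacyDDPMDiscretization is available from sgm.modules.diffusionmodules.discretizer; that class name is an assumption about the package, not something shown in this diff.

from sgm.modules.diffusionmodules.discretizer import LegacyDDPMDiscretization  # assumed to exist

from scripts.demo.discretization import Img2ImgDiscretizationWrapper

# strength=0.5 keeps roughly the lower-noise half of the schedule, which is the
# usual img2img behaviour: start denoising from a partially noised input image.
discretization = Img2ImgDiscretizationWrapper(LegacyDDPMDiscretization(), strength=0.5)
sigmas = discretization(10)  # request a 10-step schedule, receive only the pruned tail
print(len(sigmas), sigmas)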
310
scripts/demo/gradio_app.py
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
# Adding this at the very top of app.py to make 'generative-models' directory discoverable
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import uuid
|
||||||
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from fire import Fire
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import ToTensor
|
||||||
|
|
||||||
|
from scripts.sampling.simple_video_sample import (
|
||||||
|
get_batch,
|
||||||
|
get_unique_embedder_keys_from_conditioner,
|
||||||
|
load_model,
|
||||||
|
)
|
||||||
|
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||||
|
from sgm.inference.helpers import embed_watermark
|
||||||
|
from sgm.util import default, instantiate_from_config
|
||||||
|
|
||||||
|
# To download all svd models
|
||||||
|
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
|
||||||
|
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
|
||||||
|
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
|
||||||
|
|
||||||
|
|
||||||
|
# Define the repo, local directory and filename
|
||||||
|
repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
|
||||||
|
filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
|
||||||
|
local_dir = "checkpoints"
|
||||||
|
local_file_path = os.path.join(local_dir, filename)
|
||||||
|
|
||||||
|
# Check if the file already exists
|
||||||
|
if not os.path.exists(local_file_path):
|
||||||
|
# If the file doesn't exist, download it
|
||||||
|
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||||
|
print("File downloaded.")
|
||||||
|
else:
|
||||||
|
print("File already exists. No need to download.")
|
||||||
|
|
||||||
|
|
||||||
|
version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models
|
||||||
|
device = "cuda"
|
||||||
|
max_64_bit_int = 2**63 - 1
|
||||||
|
|
||||||
|
if version == "svd_xt_1_1":
|
||||||
|
num_frames = 25
|
||||||
|
num_steps = 30
|
||||||
|
model_config = "scripts/sampling/configs/svd_xt_1_1.yaml"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Version {version} does not exist.")
|
||||||
|
|
||||||
|
model, filter = load_model(
|
||||||
|
model_config,
|
||||||
|
device,
|
||||||
|
num_frames,
|
||||||
|
num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
randomize_seed: bool = True,
|
||||||
|
motion_bucket_id: int = 127,
|
||||||
|
fps_id: int = 6,
|
||||||
|
version: str = "svd_xt_1_1",
|
||||||
|
cond_aug: float = 0.02,
|
||||||
|
decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||||
|
device: str = "cuda",
|
||||||
|
output_folder: str = "outputs",
|
||||||
|
progress=gr.Progress(track_tqdm=True),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
|
||||||
|
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
|
||||||
|
"""
|
||||||
|
fps_id = int(fps_id)  # casting float slider values to int
|
||||||
|
if randomize_seed:
|
||||||
|
seed = random.randint(0, max_64_bit_int)
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
path = Path(input_path)
|
||||||
|
all_img_paths = []
|
||||||
|
if path.is_file():
|
||||||
|
if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
|
||||||
|
all_img_paths = [input_path]
|
||||||
|
else:
|
||||||
|
raise ValueError("Path is not valid image file.")
|
||||||
|
elif path.is_dir():
|
||||||
|
all_img_paths = sorted(
|
||||||
|
[
|
||||||
|
f
|
||||||
|
for f in path.iterdir()
|
||||||
|
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if len(all_img_paths) == 0:
|
||||||
|
raise ValueError("Folder does not contain any images.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Input path must be an image file or a folder of images.")
|
||||||
|
|
||||||
|
for input_img_path in all_img_paths:
|
||||||
|
with Image.open(input_img_path) as image:
|
||||||
|
if image.mode == "RGBA":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
if h % 64 != 0 or w % 64 != 0:
|
||||||
|
width, height = map(lambda x: x - x % 64, (w, h))
|
||||||
|
image = image.resize((width, height))
|
||||||
|
print(
|
||||||
|
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
|
||||||
|
)
|
||||||
|
|
||||||
|
image = ToTensor()(image)
|
||||||
|
image = image * 2.0 - 1.0
|
||||||
|
|
||||||
|
image = image.unsqueeze(0).to(device)
|
||||||
|
H, W = image.shape[2:]
|
||||||
|
assert image.shape[1] == 3
|
||||||
|
F = 8
|
||||||
|
C = 4
|
||||||
|
shape = (num_frames, C, H // F, W // F)
|
||||||
|
if (H, W) != (576, 1024):
|
||||||
|
print(
|
||||||
|
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
|
||||||
|
)
|
||||||
|
if motion_bucket_id > 255:
|
||||||
|
print(
|
||||||
|
"WARNING: High motion bucket! This may lead to suboptimal performance."
|
||||||
|
)
|
||||||
|
|
||||||
|
if fps_id < 5:
|
||||||
|
print("WARNING: Small fps value! This may lead to suboptimal performance.")
|
||||||
|
|
||||||
|
if fps_id > 30:
|
||||||
|
print("WARNING: Large fps value! This may lead to suboptimal performance.")
|
||||||
|
|
||||||
|
value_dict = {}
|
||||||
|
value_dict["motion_bucket_id"] = motion_bucket_id
|
||||||
|
value_dict["fps_id"] = fps_id
|
||||||
|
value_dict["cond_aug"] = cond_aug
|
||||||
|
value_dict["cond_frames_without_noise"] = image
|
||||||
|
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
|
||||||
|
value_dict["cond_aug"] = cond_aug
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with torch.autocast(device):
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
[1, num_frames],
|
||||||
|
T=num_frames,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=[
|
||||||
|
"cond_frames",
|
||||||
|
"cond_frames_without_noise",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in ["crossattn", "concat"]:
|
||||||
|
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
|
||||||
|
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
|
||||||
|
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
|
||||||
|
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
|
||||||
|
|
||||||
|
randn = torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
additional_model_inputs = {}
|
||||||
|
additional_model_inputs["image_only_indicator"] = torch.zeros(
|
||||||
|
2, num_frames
|
||||||
|
).to(device)
|
||||||
|
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
|
||||||
|
|
||||||
|
def denoiser(input, sigma, c):
|
||||||
|
return model.denoiser(
|
||||||
|
model.model, input, sigma, c, **additional_model_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
|
||||||
|
model.en_and_decode_n_samples_a_time = decoding_t
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
|
||||||
|
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
|
||||||
|
writer = cv2.VideoWriter(
|
||||||
|
video_path,
|
||||||
|
cv2.VideoWriter_fourcc(*"mp4v"),
|
||||||
|
fps_id + 1,
|
||||||
|
(samples.shape[-1], samples.shape[-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
samples = embed_watermark(samples)
|
||||||
|
samples = filter(samples)
|
||||||
|
vid = (
|
||||||
|
(rearrange(samples, "t c h w -> t h w c") * 255)
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
.astype(np.uint8)
|
||||||
|
)
|
||||||
|
for frame in vid:
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||||
|
writer.write(frame)
|
||||||
|
writer.release()
|
||||||
|
|
||||||
|
return video_path, seed
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image(image_path, output_size=(1024, 576)):
|
||||||
|
image = Image.open(image_path)
|
||||||
|
# Calculate aspect ratios
|
||||||
|
target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
|
||||||
|
image_aspect = image.width / image.height # Aspect ratio of the original image
|
||||||
|
|
||||||
|
# Resize then crop if the original image is larger
|
||||||
|
if image_aspect > target_aspect:
|
||||||
|
# Resize the image to match the target height, maintaining aspect ratio
|
||||||
|
new_height = output_size[1]
|
||||||
|
new_width = int(new_height * image_aspect)
|
||||||
|
resized_image = image.resize((new_width, new_height), Image.LANCZOS)
|
||||||
|
# Calculate coordinates for cropping
|
||||||
|
left = (new_width - output_size[0]) / 2
|
||||||
|
top = 0
|
||||||
|
right = (new_width + output_size[0]) / 2
|
||||||
|
bottom = output_size[1]
|
||||||
|
else:
|
||||||
|
# Resize the image to match the target width, maintaining aspect ratio
|
||||||
|
new_width = output_size[0]
|
||||||
|
new_height = int(new_width / image_aspect)
|
||||||
|
resized_image = image.resize((new_width, new_height), Image.LANCZOS)
|
||||||
|
# Calculate coordinates for cropping
|
||||||
|
left = 0
|
||||||
|
top = (new_height - output_size[1]) / 2
|
||||||
|
right = output_size[0]
|
||||||
|
bottom = (new_height + output_size[1]) / 2
|
||||||
|
|
||||||
|
# Crop the image
|
||||||
|
cropped_image = resized_image.crop((left, top, right, bottom))
|
||||||
|
|
||||||
|
return cropped_image
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
gr.Markdown(
|
||||||
|
"""# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
|
||||||
|
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate `4s` vid from a single image at (`25 frames` at `6 fps`). Generation takes ~60s in an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
image = gr.Image(label="Upload your image", type="filepath")
|
||||||
|
generate_btn = gr.Button("Generate")
|
||||||
|
video = gr.Video()
|
||||||
|
with gr.Accordion("Advanced options", open=False):
|
||||||
|
seed = gr.Slider(
|
||||||
|
label="Seed",
|
||||||
|
value=42,
|
||||||
|
randomize=True,
|
||||||
|
minimum=0,
|
||||||
|
maximum=max_64_bit_int,
|
||||||
|
step=1,
|
||||||
|
)
|
||||||
|
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
||||||
|
motion_bucket_id = gr.Slider(
|
||||||
|
label="Motion bucket id",
|
||||||
|
info="Controls how much motion to add/remove from the image",
|
||||||
|
value=127,
|
||||||
|
minimum=1,
|
||||||
|
maximum=255,
|
||||||
|
)
|
||||||
|
fps_id = gr.Slider(
|
||||||
|
label="Frames per second",
|
||||||
|
info="The length of your video in seconds will be 25/fps",
|
||||||
|
value=6,
|
||||||
|
minimum=5,
|
||||||
|
maximum=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
|
||||||
|
generate_btn.click(
|
||||||
|
fn=sample,
|
||||||
|
inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
|
||||||
|
outputs=[video, seed],
|
||||||
|
api_name="video",
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo.queue(max_size=20)
|
||||||
|
demo.launch(share=True)
|
||||||
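The Gradio wiring above ultimately just calls sample(). For completeness, here is a hedged sketch of invoking it directly; it assumes the script is importable as scripts.demo.gradio_app, and note that importing it triggers the checkpoint download and model load at module scope.

from scripts.demo.gradio_app import sample  # import runs the module-level model setup

video_path, used_seed = sample(
    input_path="assets/test_image.png",
    seed=23,
    randomize_seed=False,
    motion_bucket_id=127,
    fps_id=6,
    decoding_t=7,  # lower this if you run out of VRAM
)
print(video_path, used_seed)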
496
scripts/demo/gradio_app_sv4d.py
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
# Adding this at the very top of app.py to make 'generative-models' directory discoverable
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))
|
||||||
|
|
||||||
|
from glob import glob
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||||
|
from scripts.demo.sv4d_helpers import (
|
||||||
|
decode_latents,
|
||||||
|
load_model,
|
||||||
|
initial_model_load,
|
||||||
|
read_video,
|
||||||
|
run_img2vid,
|
||||||
|
prepare_inputs,
|
||||||
|
do_sample_per_step,
|
||||||
|
sample_sv3d,
|
||||||
|
save_video,
|
||||||
|
preprocess_video,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# the tmp path, if /tmp/gradio is not writable, change it to a writable path
|
||||||
|
# os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp"
|
||||||
|
|
||||||
|
version = "sv4d" # replace with 'sv3d_p' or 'sv3d_u' for other models
|
||||||
|
|
||||||
|
# Define the repo, local directory and filename
|
||||||
|
repo_id = "stabilityai/sv4d"
|
||||||
|
filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
|
||||||
|
local_dir = "checkpoints"
|
||||||
|
local_ckpt_path = os.path.join(local_dir, filename)
|
||||||
|
|
||||||
|
# Check if the file already exists
|
||||||
|
if not os.path.exists(local_ckpt_path):
|
||||||
|
# If the file doesn't exist, download it
|
||||||
|
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||||
|
print("File downloaded. (sv4d)")
|
||||||
|
else:
|
||||||
|
print("File already exists. No need to download. (sv4d)")
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
max_64_bit_int = 2**63 - 1
|
||||||
|
|
||||||
|
num_frames = 21
|
||||||
|
num_steps = 20
|
||||||
|
model_config = f"scripts/sampling/configs/{version}.yaml"
|
||||||
|
|
||||||
|
# Set model config
|
||||||
|
T = 5 # number of frames per sample
|
||||||
|
V = 8 # number of views per sample
|
||||||
|
F = 8 # vae factor to downsize image->latent
|
||||||
|
C = 4
|
||||||
|
H, W = 576, 576
|
||||||
|
n_frames = 21 # number of input and output video frames
|
||||||
|
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||||
|
n_views_sv3d = 21
|
||||||
|
subsampled_views = np.array(
|
||||||
|
[0, 2, 5, 7, 9, 12, 14, 16, 19]
|
||||||
|
) # subsample (V+1=)9 (uniform) views from 21 SV3D views
|
||||||
|
|
||||||
|
version_dict = {
|
||||||
|
"T": T * V,
|
||||||
|
"H": H,
|
||||||
|
"W": W,
|
||||||
|
"C": C,
|
||||||
|
"f": F,
|
||||||
|
"options": {
|
||||||
|
"discretization": 1,
|
||||||
|
"cfg": 3,
|
||||||
|
"sigma_min": 0.002,
|
||||||
|
"sigma_max": 700.0,
|
||||||
|
"rho": 7.0,
|
||||||
|
"guider": 5,
|
||||||
|
"num_steps": num_steps,
|
||||||
|
"force_uc_zero_embeddings": [
|
||||||
|
"cond_frames",
|
||||||
|
"cond_frames_without_noise",
|
||||||
|
"cond_view",
|
||||||
|
"cond_motion",
|
||||||
|
],
|
||||||
|
"additional_guider_kwargs": {
|
||||||
|
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load SV4D model
|
||||||
|
model, filter = load_model(
|
||||||
|
model_config,
|
||||||
|
device,
|
||||||
|
version_dict["T"],
|
||||||
|
num_steps,
|
||||||
|
)
|
||||||
|
model = initial_model_load(model)
|
||||||
|
|
||||||
|
# -----------sv3d config and model loading----------------
|
||||||
|
# if version == "sv3d_u":
|
||||||
|
sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml"
|
||||||
|
# elif version == "sv3d_p":
|
||||||
|
# sv3d_model_config = "scripts/sampling/configs/sv3d_p.yaml"
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Version {version} does not exist.")
|
||||||
|
|
||||||
|
# Define the repo, local directory and filename
|
||||||
|
repo_id = "stabilityai/sv3d"
|
||||||
|
filename = f"sv3d_u.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors"
|
||||||
|
local_dir = "checkpoints"
|
||||||
|
local_ckpt_path = os.path.join(local_dir, filename)
|
||||||
|
|
||||||
|
# Check if the file already exists
|
||||||
|
if not os.path.exists(local_ckpt_path):
|
||||||
|
# If the file doesn't exist, download it
|
||||||
|
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
||||||
|
print("File downloaded. (sv3d)")
|
||||||
|
else:
|
||||||
|
print("File already exists. No need to download. (sv3d)")
|
||||||
|
|
||||||
|
# load sv3d model
|
||||||
|
sv3d_model, filter = load_model(
|
||||||
|
sv3d_model_config,
|
||||||
|
device,
|
||||||
|
21,
|
||||||
|
num_steps,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
sv3d_model = initial_model_load(sv3d_model)
|
||||||
|
# ------------------
|
||||||
|
|
||||||
|
def sample_anchor(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    seed: Optional[int] = None,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    num_steps: int = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
    verbose: Optional[bool] = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    output_folder = os.path.dirname(input_path)

    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Sample multi-view images of the first frame using SV3D i.e. images at time 0
    sv3d_model.sampler.num_steps = num_steps
    print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps)
    images_t0 = sample_sv3d(
        images_v0[0],
        n_views_sv3d,
        num_steps,
        sv3d_version,
        fps_id,
        motion_bucket_id,
        cond_aug,
        decoding_t,
        device,
        polars_rad,
        azimuths_rad,
        verbose,
        sv3d_model,
    )
    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame

    sv3d_file = os.path.join(output_folder, "t000.mp4")
    save_video(sv3d_file, images_t0.unsqueeze(1))

    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t
    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1
    print(f"Sampling anchor frames {frame_indices}")
    image = img_matrix[t0][v0]
    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
    model.sampler.num_steps = num_steps
    version_dict["options"]["num_steps"] = num_steps
    samples = run_img2vid(
        version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
    )
    samples = samples.view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # concat video
    grid_list = []
    for t in frame_indices:
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
    # save output videos
    anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4")
    save_video(anchor_vis_file, grid_list, fps=3)
    anchor_file = os.path.join(output_folder, "anchor.mp4")
    image_list = samples.view(T * V, 3, H, W).unsqueeze(1) * 2 - 1
    save_video(anchor_file, image_list)

    return sv3d_file, anchor_vis_file, anchor_file

def sample_all(
    input_path: str = "inputs/test_video1.mp4",  # Can either be video file or folder with image files
    sv3d_path: str = "outputs/sv4d/000000_t000.mp4",
    anchor_path: str = "outputs/sv4d/000000_anchor.mp4",
    seed: Optional[int] = None,
    num_steps: int = 20,
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """
    output_folder = os.path.dirname(input_path)
    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    images_v0 = read_video(
        input_path,
        n_frames=n_frames,
        device=device,
    )

    images_t0 = read_video(
        sv3d_path,
        n_frames=n_views_sv3d,
        device=device,
    )

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v]
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # load interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1

    anchor_frames = read_video(
        anchor_path,
        n_frames=T * V,
        device=device,
    )
    anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = anchor_frames[i, j][None]

    # Dense sampling for the rest
    print("Sampling dense frames:")
    for t0 in np.arange(0, n_frames - 1, T - 1):  # [0, 4, 8, 12, 16]
        frame_indices = t0 + np.arange(T)
        print(f"Sampling dense frames {frame_indices}")
        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")

        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)

        # alternate between forward and backward conditioning
        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
            frame_indices,
            img_matrix,
            v0,
            view_indices,
            model,
            version_dict,
            seed,
            polars,
            azims,
        )

        for step in range(num_steps):
            if step % 2 == 1:
                c, uc, additional_model_inputs, sampler = forward_inputs
                frame_indices = forward_frame_indices
            else:
                c, uc, additional_model_inputs, sampler = backward_inputs
                frame_indices = backward_frame_indices
            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)

            samples = do_sample_per_step(
                model,
                sampler,
                noisy_latents,
                c,
                uc,
                step,
                additional_model_inputs,
            )
            samples = samples.view(T, V, C, H // F, W // F)
            for i, t in enumerate(frame_indices):
                for j, v in enumerate(view_indices):
                    latent_matrix[t, v] = samples[i, j]

        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

    # concat video
    grid_list = []
    for t in range(n_frames):
        imgs_view = torch.cat(img_matrix[t])
        grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0))
    # save output videos
    vid_file = os.path.join(output_folder, "sv4d_final.mp4")
    save_video(vid_file, grid_list)

    return vid_file, seed

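Both entry points above are ordinary Python functions, so the two-step pipeline can also be driven without the Gradio UI. The snippet below is an illustrative sketch only: it assumes the module-level SV4D/SV3D model setup above has already executed, and the input path is just an example.

# Hypothetical offline driver for the two steps wired into the UI below.
# Assumes the SV4D/SV3D models above are loaded and the input video exists.
video_path = "assets/sv4d_videos/test_video1.mp4"

# Step 1: SV3D views of frame 0 plus interleaved anchor-frame sampling.
sv3d_file, anchor_vis_file, anchor_file = sample_anchor(
    input_path=video_path, seed=23, encoding_t=8, decoding_t=4, num_steps=20
)

# Step 2: dense sampling of all 21 frames, conditioned on the anchors.
final_video, _ = sample_all(
    input_path=video_path,
    sv3d_path=sv3d_file,
    anchor_path=anchor_file,
    seed=23,
    num_steps=20,
)
print("SV4D result written to", final_video)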
with gr.Blocks() as demo:
    gr.Markdown(
        """# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel-view videos from a single-view video (with a white background).
#### It takes ~45s to generate the anchor frames and another ~160s to generate the full results (21 frames).
#### Hints for improving performance:
- Use a white background;
- Center the object in the image;
- SV4D processes the first 21 frames of the uploaded video; Gradio lets you trim the uploaded video if needed.
"""
    )
    with gr.Row():
        with gr.Column():
            input_video = gr.Video(label="Upload your video")
            generate_btn = gr.Button("Step 1: generate 8 novel view videos (5 anchor frames each)")
            interpolate_btn = gr.Button("Step 2: Extend novel view videos to 21 frames")
        with gr.Column():
            anchor_video = gr.Video(label="SV4D outputs (anchor frames)")
            sv3d_video = gr.Video(label="SV3D outputs", interactive=False)
        with gr.Column():
            sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)")

    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            value=23,
            # randomize=True,
            minimum=0,
            maximum=100,
            step=1,
        )
        encoding_t = gr.Slider(
            label="Encode n frames at a time",
            info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.",
            value=8,
            minimum=1,
            maximum=40,
        )
        decoding_t = gr.Slider(
            label="Decode n frames at a time",
            info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.",
            value=4,
            minimum=1,
            maximum=14,
        )
        denoising_steps = gr.Slider(
            label="Number of denoising steps",
            info="Increasing this improves quality but takes more time.",
            value=20,
            minimum=10,
            maximum=50,
            step=1,
        )
        remove_bg = gr.Checkbox(
            label="Remove background",
            info="We use rembg. SAM2 (https://github.com/facebookresearch/segment-anything-2) is an alternative.",
        )

    input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False)

    with gr.Row(visible=False):
        anchor_frames = gr.Video()

    generate_btn.click(
        fn=sample_anchor,
        inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps],
        outputs=[sv3d_video, anchor_video, anchor_frames],
        api_name="SV4D output (5 frames)",
    )

    interpolate_btn.click(
        fn=sample_all,
        inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps],
        outputs=[sv4d_interpolated_video, seed],
        api_name="SV4D interpolation (21 frames)",
    )

    examples = gr.Examples(
        fn=preprocess_video,
        examples=[
            "./assets/sv4d_videos/test_video1.mp4",
            "./assets/sv4d_videos/test_video2.mp4",
            "./assets/sv4d_videos/green_robot.mp4",
            "./assets/sv4d_videos/dolphin.mp4",
            "./assets/sv4d_videos/lucia_v000.mp4",
            "./assets/sv4d_videos/snowboard_v000.mp4",
            "./assets/sv4d_videos/stroller_v000.mp4",
            "./assets/sv4d_videos/human5.mp4",
            "./assets/sv4d_videos/bunnyman.mp4",
            "./assets/sv4d_videos/hiphop_parrot.mp4",
            "./assets/sv4d_videos/guppie_v0.mp4",
            "./assets/sv4d_videos/wave_hello.mp4",
            "./assets/sv4d_videos/pistol_v0.mp4",
            "./assets/sv4d_videos/human7.mp4",
            "./assets/sv4d_videos/monkey.mp4",
            "./assets/sv4d_videos/train_v0.mp4",
        ],
        inputs=[input_video],
        run_on_click=True,
        outputs=[input_video],
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)

scripts/demo/sampling.py
@@ -1,6 +1,6 @@
 from pytorch_lightning import seed_everything
 
 from scripts.demo.streamlit_helpers import *
-from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 
 SAVE_PATH = "outputs/demo/txt2img/"
 
@@ -34,7 +34,16 @@ SD_XL_BASE_RATIOS = {
 }
 
 VERSION2SPECS = {
-    "SD-XL base": {
+    "SDXL-base-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": False,
+        "config": "configs/inference/sd_xl_base.yaml",
+        "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
+    },
+    "SDXL-base-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
@@ -42,28 +51,8 @@ VERSION2SPECS = {
         "is_legacy": False,
         "config": "configs/inference/sd_xl_base.yaml",
         "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
-        "is_guided": True,
     },
-    "sd-2.1": {
-        "H": 512,
-        "W": 512,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1.yaml",
-        "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
-        "is_guided": True,
-    },
-    "sd-2.1-768": {
-        "H": 768,
-        "W": 768,
-        "C": 4,
-        "f": 8,
-        "is_legacy": True,
-        "config": "configs/inference/sd_2_1_768.yaml",
-        "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
-    },
-    "SDXL-Refiner": {
+    "SDXL-refiner-0.9": {
         "H": 1024,
         "W": 1024,
         "C": 4,
@@ -71,7 +60,15 @@ VERSION2SPECS = {
         "is_legacy": True,
         "config": "configs/inference/sd_xl_refiner.yaml",
         "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
-        "is_guided": True,
+    },
+    "SDXL-refiner-1.0": {
+        "H": 1024,
+        "W": 1024,
+        "C": 4,
+        "f": 8,
+        "is_legacy": True,
+        "config": "configs/inference/sd_xl_refiner.yaml",
+        "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
     },
 }
 
@@ -95,18 +92,19 @@ def load_img(display=True, key=None, device="cuda"):
 
 
 def run_txt2img(
-    state, version, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
-    if version == "SD-XL base":
-        ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
-        W, H = SD_XL_BASE_RATIOS[ratio]
+    if version.startswith("SDXL-base"):
+        W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        H = st.sidebar.number_input(
-            "H", value=version_dict["H"], min_value=64, max_value=2048
-        )
-        W = st.sidebar.number_input(
-            "W", value=version_dict["W"], min_value=64, max_value=2048
-        )
+        H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
+        W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
     C = version_dict["C"]
     F = version_dict["f"]
 
@@ -122,10 +120,7 @@ def run_txt2img(
         prompt=prompt,
         negative_prompt=negative_prompt,
     )
-    num_rows, num_cols, sampler = init_sampling(
-        use_identity_guider=not version_dict["is_guided"]
-    )
+    sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
 
     num_samples = num_rows * num_cols
 
     if st.button("Sample"):
@@ -147,7 +142,12 @@ def run_txt2img(
 
 
 def run_img2img(
-    state, version_dict, is_legacy=False, return_latents=False, filter=None
+    state,
+    version_dict,
+    is_legacy=False,
+    return_latents=False,
+    filter=None,
+    stage2strength=None,
 ):
     img = load_img()
     if img is None:
@@ -163,13 +163,15 @@ def run_img2img(
     value_dict = init_embedder_options(
        get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
        init_dict,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
    )
    strength = st.number_input(
-        "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
+        "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
    )
-    num_rows, num_cols, sampler = init_sampling(
+    sampler, num_rows, num_cols = init_sampling(
        img2img_strength=strength,
-        use_identity_guider=not version_dict["is_guided"],
+        stage2strength=stage2strength,
    )
    num_samples = num_rows * num_cols
 
@@ -195,6 +197,7 @@ def apply_refiner(
     prompt,
     negative_prompt,
     filter=None,
+    finish_denoising=False,
 ):
     init_dict = {
         "orig_width": input.shape[3] * 8,
@@ -222,6 +225,7 @@ def apply_refiner(
         num_samples,
         skip_encode=True,
         filter=filter,
+        add_noise=not finish_denoising,
     )
 
     return samples
@@ -231,26 +235,30 @@ if __name__ == "__main__":
     st.title("Stable Diffusion")
     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
     version_dict = VERSION2SPECS[version]
-    mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+    if st.checkbox("Load Model"):
+        mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+    else:
+        mode = "skip"
     st.write("__________________________")
 
-    if version == "SD-XL base":
-        add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
+    set_lowvram_mode(st.checkbox("Low vram mode", True))
+
+    if version.startswith("SDXL-base"):
+        add_pipeline = st.checkbox("Load SDXL-refiner?", False)
         st.write("__________________________")
     else:
        add_pipeline = False
 
-    filter = DeepFloydDataFiltering(verbose=False)
-
     seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
     seed_everything(seed)
 
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
 
-    state = init_st(version_dict)
-    if state["msg"]:
-        st.info(state["msg"])
-    model = state["model"]
+    if mode != "skip":
+        state = init_st(version_dict, load_filter=True)
+        if state["msg"]:
+            st.info(state["msg"])
+        model = state["model"]
 
     is_legacy = version_dict["is_legacy"]
 
@@ -263,30 +271,34 @@ if __name__ == "__main__":
     else:
         negative_prompt = ""  # which is unused
 
+    stage2strength = None
+    finish_denoising = False
+
     if add_pipeline:
         st.write("__________________________")
-        version2 = "SDXL-Refiner"
+        version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
         st.warning(
             f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
         )
         st.write("**Refiner Options:**")
 
         version_dict2 = VERSION2SPECS[version2]
-        state2 = init_st(version_dict2)
+        state2 = init_st(version_dict2, load_filter=False)
         st.info(state2["msg"])
 
         stage2strength = st.number_input(
-            "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
+            "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
        )
 
-        sampler2 = init_sampling(
+        sampler2, *_ = init_sampling(
            key=2,
            img2img_strength=stage2strength,
-            use_identity_guider=not version_dict["is_guided"],
-            get_num_samples=False,
+            specify_num_samples=False,
        )
        st.write("__________________________")
+        finish_denoising = st.checkbox("Finish denoising with refiner.", True)
+        if not finish_denoising:
+            stage2strength = None
 
     if mode == "txt2img":
         out = run_txt2img(
@@ -295,7 +307,8 @@ if __name__ == "__main__":
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
     elif mode == "img2img":
         out = run_img2img(
@@ -303,16 +316,20 @@ if __name__ == "__main__":
             version_dict,
             is_legacy=is_legacy,
             return_latents=add_pipeline,
-            filter=filter,
+            filter=state.get("filter"),
+            stage2strength=stage2strength,
         )
+    elif mode == "skip":
+        out = None
     else:
         raise ValueError(f"unknown mode {mode}")
     if isinstance(out, (tuple, list)):
         samples, samples_z = out
     else:
         samples = out
+        samples_z = None
 
-    if add_pipeline:
+    if add_pipeline and samples_z is not None:
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -321,7 +338,8 @@ if __name__ == "__main__":
             samples_z.shape[0],
             prompt=prompt,
             negative_prompt=negative_prompt if is_legacy else "",
-            filter=filter,
+            filter=state.get("filter"),
+            finish_denoising=finish_denoising,
         )
 
     if save_locally and samples is not None:
scripts/demo/streamlit_helpers.py
@@ -1,78 +1,48 @@
-import os
-from typing import Union, List
+import copy
 
 import math
+import os
+from glob import glob
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import imageio
 import numpy as np
 import streamlit as st
 import torch
-from PIL import Image
+import torch.nn as nn
+import torchvision.transforms as TT
 from einops import rearrange, repeat
 from imwatermark import WatermarkEncoder
-from omegaconf import OmegaConf, ListConfig
-from torch import autocast
-from torchvision import transforms
-from torchvision.utils import make_grid
+from omegaconf import ListConfig, OmegaConf
+from PIL import Image
 from safetensors.torch import load_file as load_safetensors
+from scripts.demo.discretization import (
+    Img2ImgDiscretizationWrapper,
+    Txt2NoisyDiscretizationWrapper,
+)
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import embed_watermark
+from sgm.modules.diffusionmodules.guiders import (
+    LinearPredictionGuider,
+    TrianglePredictionGuider,
+    VanillaCFG,
+)
 from sgm.modules.diffusionmodules.sampling import (
+    DPMPP2MSampler,
+    DPMPP2SAncestralSampler,
+    EulerAncestralSampler,
     EulerEDMSampler,
     HeunEDMSampler,
-    EulerAncestralSampler,
-    DPMPP2SAncestralSampler,
-    DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import append_dims
-from sgm.util import instantiate_from_config
+from sgm.util import append_dims, default, instantiate_from_config
+from torch import autocast
+from torchvision import transforms
+from torchvision.utils import make_grid, save_image
 
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor):
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        # watermarking libary expects input as cv2 format
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(
-            rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
-        ).to(image.device)
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        return image
-
-
-# A fixed 48-bit message that was choosen at random
-# WATERMARK_MESSAGE = 0xB3EC907BB19E
-WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
 
 
 @st.cache_resource()
-def init_st(version_dict, load_ckpt=True):
+def init_st(version_dict, load_ckpt=True, load_filter=True):
     state = dict()
     if not "model" in state:
         config = version_dict["config"]
@@ -85,9 +55,39 @@ def init_st(version_dict, load_ckpt=True):
         state["model"] = model
         state["ckpt"] = ckpt if load_ckpt else None
         state["config"] = config
+        if load_filter:
+            state["filter"] = DeepFloydDataFiltering(verbose=False)
     return state
 
 
+def load_model(model):
+    model.cuda()
+
+
+lowvram_mode = False
+
+
+def set_lowvram_mode(mode):
+    global lowvram_mode
+    lowvram_mode = mode
+
+
+def initial_model_load(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.model.half()
+    else:
+        model.cuda()
+    return model
+
+
+def unload_model(model):
+    global lowvram_mode
+    if lowvram_mode:
+        model.cpu()
+        torch.cuda.empty_cache()
+
+
 def load_model_from_config(config, ckpt=None, verbose=True):
     model = instantiate_from_config(config.model)
 
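The load_model/unload_model pair added here is the low-VRAM strategy used throughout the helpers: each sub-module is moved to the GPU only around the call that needs it and pushed back to the CPU afterwards. Below is a minimal, self-contained sketch of the same pattern with dummy modules; it assumes a CUDA device, and the module names are stand-ins rather than parts of the repo.

import torch
import torch.nn as nn

lowvram_mode = True  # as toggled by set_lowvram_mode(True)

def load_module(m: nn.Module):
    m.cuda()

def unload_module(m: nn.Module):
    if lowvram_mode:
        m.cpu()
        torch.cuda.empty_cache()

# Stand-ins for the conditioner / denoiser stages of the real pipeline.
conditioner, denoiser = nn.Linear(8, 8), nn.Linear(8, 8)
x = torch.randn(1, 8)

load_module(conditioner)       # on the GPU only while it is needed
cond = conditioner(x.cuda())
unload_module(conditioner)     # back to CPU, VRAM freed for the next stage

load_module(denoiser)
out = denoiser(cond)
unload_module(denoiser)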
@@ -118,7 +118,7 @@ def load_model_from_config(config, ckpt=None, verbose=True):
     else:
         msg = None
 
-    model.cuda()
+    model = initial_model_load(model)
     model.eval()
     return model, msg
 
@@ -134,11 +134,12 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
     for key in keys:
         if key == "txt":
             if prompt is None:
-                prompt = st.text_input(
-                    "Prompt", "A professional photograph of an astronaut riding a pig"
-                )
+                prompt = "A professional photograph of an astronaut riding a pig"
             if negative_prompt is None:
-                negative_prompt = st.text_input("Negative prompt", "")
+                negative_prompt = ""
 
+            prompt = st.text_input("Prompt", prompt)
+            negative_prompt = st.text_input("Negative prompt", negative_prompt)
+
             value_dict["prompt"] = prompt
             value_dict["negative_prompt"] = negative_prompt
@@ -170,19 +171,30 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
             value_dict["negative_aesthetic_score"] = 2.5
 
         if key == "target_size_as_tuple":
-            target_width = st.number_input(
-                "target_width",
-                value=init_dict["target_width"],
-                min_value=16,
-            )
-            target_height = st.number_input(
-                "target_height",
-                value=init_dict["target_height"],
-                min_value=16,
-            )
+            value_dict["target_width"] = init_dict["target_width"]
+            value_dict["target_height"] = init_dict["target_height"]
 
-            value_dict["target_width"] = target_width
-            value_dict["target_height"] = target_height
+        if key in ["fps_id", "fps"]:
+            fps = st.number_input("fps", value=6, min_value=1)
+
+            value_dict["fps"] = fps
+            value_dict["fps_id"] = fps - 1
+
+        if key == "motion_bucket_id":
+            mb_id = st.number_input("motion bucket id", 0, 511, value=127)
+            value_dict["motion_bucket_id"] = mb_id
+
+        if key == "pool_image":
+            st.text("Image for pool conditioning")
+            image = load_img(
+                key="pool_image_input",
+                size=224,
+                center_crop=True,
+            )
+            if image is None:
+                st.info("Need an image here")
+                image = torch.zeros(1, 3, 224, 224)
+            value_dict["pool_image"] = image
 
     return value_dict
 
@@ -190,7 +202,7 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watemark(samples)
+    samples = embed_watermark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(
@@ -209,65 +221,82 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path
 
 
-class Img2ImgDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 1.0):
-        self.discretization = discretization
-        self.strength = strength
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
-        print("prune index:", max(int(self.strength * len(sigmas)), 1))
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-def get_guider(key):
+def get_guider(options, key):
     guider = st.sidebar.selectbox(
         f"Discretization #{key}",
         [
             "VanillaCFG",
             "IdentityGuider",
+            "LinearPredictionGuider",
+            "TrianglePredictionGuider",
         ],
+        options.get("guider", 0),
     )
 
+    additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
+
     if guider == "IdentityGuider":
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
         }
     elif guider == "VanillaCFG":
         scale = st.number_input(
-            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+            f"cfg-scale #{key}",
+            value=options.get("cfg", 5.0),
+            min_value=0.0,
         )
 
-        thresholder = st.sidebar.selectbox(
-            f"Thresholder #{key}",
-            [
-                "None",
-            ],
-        )
-
-        if thresholder == "None":
-            dyn_thresh_config = {
-                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
-            }
-        else:
-            raise NotImplementedError
-
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
-            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+            "params": {
+                "scale": scale,
+                **additional_guider_kwargs,
+            },
+        }
+    elif guider == "LinearPredictionGuider":
+        max_scale = st.number_input(
+            f"max-cfg-scale #{key}",
+            value=options.get("cfg", 1.5),
+            min_value=1.0,
+        )
+        min_scale = st.sidebar.number_input(
+            f"min guidance scale",
+            value=options.get("min_cfg", 1.0),
+            min_value=1.0,
+            max_value=10.0,
+        )
+
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
+            "params": {
+                "max_scale": max_scale,
+                "min_scale": min_scale,
+                "num_frames": options["num_frames"],
+                **additional_guider_kwargs,
+            },
+        }
+    elif guider == "TrianglePredictionGuider":
+        max_scale = st.number_input(
+            f"max-cfg-scale #{key}",
+            value=options.get("cfg", 2.5),
+            min_value=1.0,
+            max_value=10.0,
+        )
+        min_scale = st.sidebar.number_input(
+            f"min guidance scale",
+            value=options.get("min_cfg", 1.0),
+            min_value=1.0,
+            max_value=10.0,
+        )
+
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
+            "params": {
+                "max_scale": max_scale,
+                "min_scale": min_scale,
+                "num_frames": options["num_frames"],
+                **additional_guider_kwargs,
+            },
         }
     else:
         raise NotImplementedError
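For reference, the video guiders added above are built from plain config dicts and instantiated through the repo's instantiate_from_config; the sketch below is illustrative only, and the parameter values are made up rather than taken from any particular yaml.

from sgm.util import instantiate_from_config

# Illustrative guider config of the shape produced by get_guider() above.
guider_config = {
    "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
    "params": {"max_scale": 1.5, "min_scale": 1.0, "num_frames": 21},
}
guider = instantiate_from_config(guider_config)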
@@ -275,16 +304,22 @@ def get_guider(key):
 
 
 def init_sampling(
-    key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
+    key=1,
+    img2img_strength: Optional[float] = None,
+    specify_num_samples: bool = True,
+    stage2strength: Optional[float] = None,
+    options: Optional[Dict[str, int]] = None,
 ):
-    if get_num_samples:
-        num_rows = 1
+    options = {} if options is None else options
+
+    num_rows, num_cols = 1, 1
+    if specify_num_samples:
         num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
+            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
         )
 
-    steps = st.sidebar.number_input(
-        f"steps #{key}", value=50, min_value=1, max_value=1000
+    steps = st.number_input(
+        f"steps #{key}", value=options.get("num_steps", 50), min_value=1, max_value=1000
     )
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
@@ -296,7 +331,7 @@ def init_sampling(
             "DPMPP2MSampler",
             "LinearMultistepSampler",
         ],
-        0,
+        options.get("sampler", 0),
     )
     discretization = st.sidebar.selectbox(
         f"Discretization #{key}",
@@ -304,36 +339,41 @@ def init_sampling(
             "LegacyDDPMDiscretization",
             "EDMDiscretization",
         ],
+        options.get("discretization", 0),
     )
 
-    discretization_config = get_discretization(discretization, key=key)
+    discretization_config = get_discretization(discretization, options=options, key=key)
 
-    guider_config = get_guider(key=key)
+    guider_config = get_guider(options=options, key=key)
 
     sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
-    if img2img_strength < 1.0:
+    if img2img_strength is not None:
         st.warning(
             f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
         )
         sampler.discretization = Img2ImgDiscretizationWrapper(
             sampler.discretization, strength=img2img_strength
         )
-    if get_num_samples:
-        return num_rows, num_cols, sampler
-    return sampler
+    if stage2strength is not None:
+        sampler.discretization = Txt2NoisyDiscretizationWrapper(
+            sampler.discretization, strength=stage2strength, original_steps=steps
+        )
+    return sampler, num_rows, num_cols
 
 
-def get_discretization(discretization, key=1):
+def get_discretization(discretization, options, key=1):
     if discretization == "LegacyDDPMDiscretization":
-        use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
-            "params": {"legacy_range": not use_new_range},
         }
     elif discretization == "EDMDiscretization":
-        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
-        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
-        rho = st.number_input(f"rho #{key}", value=3.0)
+        sigma_min = st.sidebar.number_input(
+            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
+        )  # 0.0292
+        sigma_max = st.sidebar.number_input(
+            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
+        )  # 14.6146
+        rho = st.sidebar.number_input(f"rho #{key}", value=options.get("rho", 3.0))
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
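The options dict threaded through init_sampling, get_discretization, and get_guider above only pre-seeds the Streamlit widgets. The dict below lists the keys those widgets read; the values and selectbox indices are illustrative assumptions, not defaults taken from any shipped config.

# Illustrative `options` dict; keys are the ones read by the widgets above.
options = {
    "num_frames": 21,        # required by the prediction guiders
    "num_steps": 30,         # default for the steps widget
    "sampler": 0,            # index into the sampler selectbox
    "discretization": 1,     # index: EDMDiscretization
    "sigma_min": 0.002,
    "sigma_max": 700.0,
    "rho": 7.0,
    "guider": 2,             # index: LinearPredictionGuider
    "cfg": 2.5,
    "min_cfg": 1.0,
}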
@@ -422,8 +462,8 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
     return sampler
 
 
-def get_interactive_image(key=None) -> Image.Image:
-    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
+def get_interactive_image() -> Image.Image:
+    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
@@ -431,8 +471,12 @@ def get_interactive_image(key=None) -> Image.Image:
     return image
 
 
-def load_img(display=True, key=None):
-    image = get_interactive_image(key=key)
+def load_img(
+    display: bool = True,
+    size: Union[None, int, Tuple[int, int]] = None,
+    center_crop: bool = False,
+):
+    image = get_interactive_image()
     if image is None:
         return None
     if display:
@@ -440,12 +484,15 @@ def load_img(display=True, key=None):
         w, h = image.size
         print(f"loaded input image of size ({w}, {h})")
 
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Lambda(lambda x: x * 2.0 - 1.0),
-        ]
-    )
+    transform = []
+    if size is not None:
+        transform.append(transforms.Resize(size))
+    if center_crop:
+        transform.append(transforms.CenterCrop(size))
+    transform.append(transforms.ToTensor())
+    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
+
+    transform = transforms.Compose(transform)
     img = transform(image)[None, ...]
     st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
     return img
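The rewritten load_img above composes its torchvision pipeline conditionally from the size and center_crop arguments. A self-contained sketch of the same construction (the helper name build_transform is made up for the example):

from PIL import Image
from torchvision import transforms

def build_transform(size=None, center_crop=False):
    tfs = []
    if size is not None:
        tfs.append(transforms.Resize(size))
    if center_crop:
        tfs.append(transforms.CenterCrop(size))
    tfs.append(transforms.ToTensor())
    tfs.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))  # map [0, 1] -> [-1, 1]
    return transforms.Compose(tfs)

img = Image.new("RGB", (640, 480), "white")
x = build_transform(size=224, center_crop=True)(img)[None, ...]
assert x.shape == (1, 3, 224, 224)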
@@ -466,15 +513,18 @@ def do_sample(
     W,
     C,
     F,
-    force_uc_zero_embeddings: List = None,
+    force_uc_zero_embeddings: Optional[List] = None,
+    force_cond_zero_embeddings: Optional[List] = None,
     batch2model_input: List = None,
     return_latents=False,
     filter=None,
+    T=None,
+    additional_batch_uc_fields=None,
+    decoding_t=None,
 ):
-    if force_uc_zero_embeddings is None:
-        force_uc_zero_embeddings = []
-    if batch2model_input is None:
-        batch2model_input = []
+    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
+    batch2model_input = default(batch2model_input, [])
+    additional_batch_uc_fields = default(additional_batch_uc_fields, [])
 
     st.text("Sampling")
 
@@ -483,34 +533,61 @@ def do_sample(
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
-                num_samples = [num_samples]
+                if T is not None:
+                    num_samples = [num_samples, T]
+                else:
+                    num_samples = [num_samples]
+
+                load_model(model.conditioner)
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     num_samples,
+                    T=T,
+                    additional_batch_uc_fields=additional_batch_uc_fields,
                 )
-                for key in batch:
-                    if isinstance(batch[key], torch.Tensor):
-                        print(key, batch[key].shape)
-                    elif isinstance(batch[key], list):
-                        print(key, [len(l) for l in batch[key]])
-                    else:
-                        print(key, batch[key])
+
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
                     batch_uc=batch_uc,
                     force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                 )
+                unload_model(model.conditioner)
+
                 for k in c:
                     if not k == "crossattn":
                         c[k], uc[k] = map(
                             lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                         )
+                    if k in ["crossattn", "concat"] and T is not None:
+                        uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
+                        uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
+                        c[k] = repeat(c[k], "b ... -> b t ...", t=T)
+                        c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
+
                 additional_model_inputs = {}
                 for k in batch2model_input:
-                    additional_model_inputs[k] = batch[k]
+                    if k == "image_only_indicator":
+                        assert T is not None
+
+                        if isinstance(
+                            sampler.guider,
+                            (
+                                VanillaCFG,
+                                LinearPredictionGuider,
+                                TrianglePredictionGuider,
+                            ),
+                        ):
+                            additional_model_inputs[k] = torch.zeros(
+                                num_samples[0] * 2, num_samples[1]
+                            ).to("cuda")
+                        else:
+                            additional_model_inputs[k] = torch.zeros(num_samples).to(
+                                "cuda"
+                            )
+                    else:
+                        additional_model_inputs[k] = batch[k]
 
                 shape = (math.prod(num_samples), C, H // F, W // F)
                 randn = torch.randn(shape).to("cuda")
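The repeat/rearrange pair added for the video path tiles per-sample conditioning so it lines up with the (b t)-flattened frame batch. A quick self-contained check of that shape logic, with illustrative tensor sizes:

import torch
from einops import rearrange, repeat

b, t = 2, 14
crossattn = torch.randn(b, 77, 1024)                    # per-sample conditioning
tiled = repeat(crossattn, "b ... -> b t ...", t=t)      # (b, t, 77, 1024)
tiled = rearrange(tiled, "b t ... -> (b t) ...", t=t)   # (b*t, 77, 1024)
assert tiled.shape == (b * t, 77, 1024)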
@@ -520,23 +597,49 @@ def do_sample(
|
|||||||
model.model, input, sigma, c, **additional_model_inputs
|
model.model, input, sigma, c, **additional_model_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
load_model(model.denoiser)
|
||||||
|
load_model(model.model)
|
||||||
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||||
|
unload_model(model.model)
|
||||||
|
unload_model(model.denoiser)
|
||||||
|
|
||||||
|
load_model(model.first_stage_model)
|
||||||
|
model.en_and_decode_n_samples_a_time = (
|
||||||
|
decoding_t # Decode n frames at a time
|
||||||
|
)
|
||||||
samples_x = model.decode_first_stage(samples_z)
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
unload_model(model.first_stage_model)
|
||||||
|
|
||||||
if filter is not None:
|
if filter is not None:
|
||||||
samples = filter(samples)
|
samples = filter(samples)
|
||||||
|
|
||||||
grid = torch.stack([samples])
|
if T is None:
|
||||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
grid = torch.stack([samples])
|
||||||
outputs.image(grid.cpu().numpy())
|
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||||
|
outputs.image(grid.cpu().numpy())
|
||||||
|
else:
|
||||||
|
as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
|
||||||
|
for i, vid in enumerate(as_vids):
|
||||||
|
grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
|
||||||
|
st.image(
|
||||||
|
grid.cpu().numpy(),
|
||||||
|
f"Sample #{i} as image",
|
||||||
|
)
|
||||||
|
|
||||||
if return_latents:
|
if return_latents:
|
||||||
return samples, samples_z
|
return samples, samples_z
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
def get_batch(
|
||||||
|
keys,
|
||||||
|
value_dict: dict,
|
||||||
|
N: Union[List, ListConfig],
|
||||||
|
device: str = "cuda",
|
||||||
|
T: int = None,
|
||||||
|
additional_batch_uc_fields: List[str] = [],
|
||||||
|
):
|
||||||
# Hardcoded demo setups; might undergo some changes in the future
|
# Hardcoded demo setups; might undergo some changes in the future
|
||||||
|
|
||||||
batch = {}
|
batch = {}
|
||||||
@@ -544,21 +647,15 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
|||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key == "txt":
|
if key == "txt":
|
||||||
batch["txt"] = (
|
batch["txt"] = [value_dict["prompt"]] * math.prod(N)
|
||||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
|
||||||
.reshape(N)
|
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
|
||||||
.tolist()
|
|
||||||
)
|
|
||||||
batch_uc["txt"] = (
|
|
||||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
|
||||||
.reshape(N)
|
|
||||||
.tolist()
|
|
||||||
)
|
|
||||||
elif key == "original_size_as_tuple":
|
elif key == "original_size_as_tuple":
|
||||||
batch["original_size_as_tuple"] = (
|
batch["original_size_as_tuple"] = (
|
||||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||||
.to(device)
|
.to(device)
|
||||||
.repeat(*N, 1)
|
.repeat(math.prod(N), 1)
|
||||||
)
|
)
|
||||||
elif key == "crop_coords_top_left":
|
elif key == "crop_coords_top_left":
|
||||||
batch["crop_coords_top_left"] = (
|
batch["crop_coords_top_left"] = (
|
||||||
@@ -566,30 +663,73 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(math.prod(N), 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]])
                .to(device)
                .repeat(math.prod(N), 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(math.prod(N), 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(math.prod(N), 1)
            )
        elif key == "fps":
            batch[key] = (
                torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
            )
        elif key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(math.prod(N))
            )
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
                device, dtype=torch.half
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        elif key == "polars_rad":
            batch[key] = torch.tensor(value_dict["polars_rad"]).to(device).repeat(N[0])
        elif key == "azimuths_rad":
            batch[key] = (
                torch.tensor(value_dict["azimuths_rad"]).to(device).repeat(N[0])
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
        elif key in additional_batch_uc_fields and key not in batch_uc:
            batch_uc[key] = copy.copy(batch[key])
    return batch, batch_uc
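For orientation, a minimal sketch of how this helper is typically driven, mirroring the call sites in do_img2img and the demo scripts further down. The prompt text, the sizes in value_dict, and the CPU device are illustrative assumptions, not values taken from the diff.

# Hypothetical usage sketch; the import path assumes the helper lives in
# scripts/demo/streamlit_helpers.py as in this diff.
import math

from scripts.demo.streamlit_helpers import get_batch

value_dict = {
    "prompt": "a photo of a cat",
    "negative_prompt": "",
    "orig_height": 1024, "orig_width": 1024,
    "target_height": 1024, "target_width": 1024,
    "crop_coords_top": 0, "crop_coords_left": 0,
}
keys = ["txt", "original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]

# N is [num_rows, num_cols] in the demos, so math.prod(N) is the effective batch size.
batch, batch_uc = get_batch(keys, value_dict, N=[1, 2], device="cpu")
assert len(batch["txt"]) == math.prod([1, 2])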
@@ -600,12 +740,14 @@ def do_img2img(
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    additional_kwargs={},
    offset_noise_level: int = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    st.text("Sampling")
@@ -614,6 +756,7 @@ def do_img2img(
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
@@ -623,8 +766,9 @@ def do_img2img(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                )
                unload_model(model.conditioner)
                for k in c:
                    c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
@@ -633,36 +777,145 @@ def do_img2img(
                if skip_encode:
                    z = img
                else:
                    load_model(model.first_stage_model)
                    z = model.encode_first_stage(img)
                    unload_model(model.first_stage_model)

                noise = torch.randn_like(z)

                sigmas = sampler.discretization(sampler.num_steps).cuda()
                sigma = sigmas[0]

                st.info(f"all sigmas: {sigmas}")
                st.info(f"noising sigma: {sigma}")

                if offset_noise_level > 0.0:
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                if add_noise:
                    noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
                    noised_z = noised_z / torch.sqrt(
                        1.0 + sigmas[0] ** 2.0
                    )  # Note: hardcoded to DDPM-like scaling. need to generalize later.
                else:
                    noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                load_model(model.denoiser)
                load_model(model.model)
                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                unload_model(model.model)
                unload_model(model.denoiser)

                load_model(model.first_stage_model)
                samples_x = model.decode_first_stage(samples_z)
                unload_model(model.first_stage_model)
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                outputs.image(grid.cpu().numpy())
                if return_latents:
                    return samples, samples_z
                return samples

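The add_noise branch above applies a DDPM-like rescaling after noising. A small, self-contained numerical check of that identity, independent of the model code; the sigma value is chosen arbitrarily.

import torch

sigma = 14.6                      # arbitrary noising sigma for illustration
z = torch.randn(4, 4, 64, 64)     # stand-in for a unit-variance latent
noise = torch.randn_like(z)

# (z + sigma * noise) has std ~sqrt(1 + sigma^2); dividing brings it back to ~1,
# which is the "hardcoded DDPM-like scaling" noted in the comment above.
noised_z = (z + noise * sigma) / (1.0 + sigma**2) ** 0.5
print(float(noised_z.std()))      # close to 1.0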
def get_resizing_factor(
    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
    r_bound = desired_shape[1] / desired_shape[0]
    aspect_r = current_shape[1] / current_shape[0]
    if r_bound >= 1.0:
        if aspect_r >= r_bound:
            factor = min(desired_shape) / min(current_shape)
        else:
            if aspect_r < 1.0:
                factor = max(desired_shape) / min(current_shape)
            else:
                factor = max(desired_shape) / max(current_shape)
    else:
        if aspect_r <= r_bound:
            factor = min(desired_shape) / min(current_shape)
        else:
            if aspect_r > 1:
                factor = max(desired_shape) / min(current_shape)
            else:
                factor = max(desired_shape) / max(current_shape)

    return factor
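A worked example of the branch logic above, with made-up shapes; it assumes get_resizing_factor from this file is in scope.

desired = (576, 1024)  # (H, W): r_bound = 1024 / 576 ≈ 1.78
current = (768, 512)   # (h, w):  aspect_r = 512 / 768 ≈ 0.67

# r_bound >= 1, aspect_r < r_bound and aspect_r < 1, so the factor is
# max(desired) / min(current) = 1024 / 512 = 2.0: the input is scaled up until
# the subsequent center crop in load_img_for_prediction can cover the full W x H.
print(get_resizing_factor(desired, current))  # 2.0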
def get_interactive_image(key=None) -> Image.Image:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        return image
def load_img_for_prediction(
    W: int, H: int, display=True, key=None, device="cuda"
) -> torch.Tensor:
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size

    image = np.array(image).astype(np.float32) / 255
    if image.shape[-1] == 4:
        rgb, alpha = image[:, :, :3], image[:, :, 3:]
        image = rgb * alpha + (1 - alpha)

    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image).to(dtype=torch.float32)
    image = image.unsqueeze(0)

    rfs = get_resizing_factor((H, W), (h, w))
    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
    top = (resize_size[0] - H) // 2
    left = (resize_size[1] - W) // 2

    image = torch.nn.functional.interpolate(
        image, resize_size, mode="area", antialias=False
    )
    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)

    if display:
        numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
        pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
        st.image(pil_image)
    return image.to(device) * 2.0 - 1.0
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
    os.makedirs(save_path, exist_ok=True)
    base_count = len(glob(os.path.join(save_path, "*.mp4")))

    video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
    video_batch = embed_watermark(video_batch)
    for vid in video_batch:
        save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

        video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
        vid = (
            (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
        )
        imageio.mimwrite(video_path, vid, fps=fps)

        video_path_h264 = video_path[:-4] + "_h264.mp4"
        os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'")
        with open(video_path_h264, "rb") as f:
            video_bytes = f.read()
        os.remove(video_path_h264)
        st.video(video_bytes)

        base_count += 1
104
scripts/demo/sv3d_helpers.py
Normal file
@@ -0,0 +1,104 @@
import os

import matplotlib.pyplot as plt
import numpy as np


def generate_dynamic_cycle_xy_values(
    length=21,
    init_elev=0,
    num_components=84,
    frequency_range=(1, 5),
    amplitude_range=(0.5, 10),
    step_range=(0, 2),
):
    # Y values generation
    y_sequence = np.ones(length) * init_elev
    for _ in range(num_components):
        # Choose a frequency that will complete whole cycles in the sequence
        frequency = np.random.randint(*frequency_range) * (2 * np.pi / length)
        amplitude = np.random.uniform(*amplitude_range)
        phase_shift = np.random.choice([0, np.pi])  # np.random.uniform(0, 2 * np.pi)
        angles = (
            np.linspace(0, frequency * length, length, endpoint=False) + phase_shift
        )
        y_sequence += np.sin(angles) * amplitude
    # X values generation
    # Generate length - 1 steps since the last step is back to start
    steps = np.random.uniform(*step_range, length - 1)
    total_step_sum = np.sum(steps)
    # Calculate the scale factor to scale total steps to just under 360
    scale_factor = (
        360 - ((360 / length) * np.random.uniform(*step_range))
    ) / total_step_sum
    # Apply the scale factor and generate the sequence of X values
    x_values = np.cumsum(steps * scale_factor)
    # Ensure the sequence starts at 0 and add the final step to complete the loop
    x_values = np.insert(x_values, 0, 0)
    return x_values, y_sequence


def smooth_data(data, window_size):
    # Extend data at both ends by wrapping around to create a continuous loop
    pad_size = window_size
    padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size]))

    # Apply smoothing
    kernel = np.ones(window_size) / window_size
    smoothed_data = np.convolve(padded_data, kernel, mode="same")

    # Extract the smoothed data corresponding to the original sequence
    # Adjust the indices to account for the larger padding
    start_index = pad_size
    end_index = -pad_size if pad_size != 0 else None
    smoothed_original_data = smoothed_data[start_index:end_index]
    return smoothed_original_data


# Function to generate and process the data
def gen_dynamic_loop(length=21, elev_deg=0):
    while True:
        # Generate the combined X and Y values using the new function
        azim_values, elev_values = generate_dynamic_cycle_xy_values(
            length=84, init_elev=elev_deg
        )
        # Smooth the Y values directly
        smoothed_elev_values = smooth_data(elev_values, 5)
        max_magnitude = np.max(np.abs(smoothed_elev_values))
        if max_magnitude < 90:
            break
    subsample = 84 // length
    azim_rad = np.deg2rad(azim_values[::subsample])
    elev_rad = np.deg2rad(smoothed_elev_values[::subsample])
    # Make cond frame the last one
    return np.roll(azim_rad, -1), np.roll(elev_rad, -1)


def plot_3D(azim, polar, save_path, dynamic=True):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    elev = np.deg2rad(90) - polar
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(projection="3d")
    cm = plt.get_cmap("Greys")
    col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]
    cm = plt.get_cmap("cool")
    col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]
    xs = np.cos(elev) * np.cos(azim)
    ys = np.cos(elev) * np.sin(azim)
    zs = np.sin(elev)
    ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])
    xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])
    for i in range(len(xs) - 1):
        if dynamic:
            ax.quiver(
                xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i]
            )
        else:
            ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i])
        ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
    ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
    ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
    ax.view_init(elev=30, azim=-20, roll=0)
    plt.savefig(save_path, bbox_inches="tight")
    plt.clf()
    plt.close()
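A sketch of how these helpers are combined by the SV3D demo further below; the output path here is an arbitrary example.

import numpy as np

from scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D

T = 21
# 21 azimuth/elevation pairs (radians) forming a smooth closed orbit whose
# smoothed elevation stays within +/-90 degrees; the conditioning frame is rolled last.
azim_rad, elev_rad = gen_dynamic_loop(length=T, elev_deg=5)
polars_rad = np.deg2rad(90) - elev_rad

plot_3D(azim=azim_rad, polar=polars_rad, save_path="outputs/demo/plot_3D.png", dynamic=True)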
1415
scripts/demo/sv4d_helpers.py
Executable file
225
scripts/demo/turbo.py
Normal file
@@ -0,0 +1,225 @@
from st_keyup import st_keyup
from streamlit_helpers import *

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler

VERSION2SPECS = {
    "SDXL-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
}


class SubstepSampler(EulerAncestralSampler):
    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        sigmas = sigmas[
            self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        ]
        uc = cond
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc


def seeded_randn(shape, seed):
    randn = np.random.RandomState(seed).randn(*shape)
    randn = torch.from_numpy(randn).to(device="cuda", dtype=torch.float32)
    return randn


class SeededNoise:
    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        self.seed = self.seed + 1
        return seeded_randn(x.shape, self.seed)


def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = ""

        if key == "original_size_as_tuple":
            orig_width = init_dict["orig_width"]
            orig_height = init_dict["orig_height"]

            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height

        if key == "crop_coords_top_left":
            crop_coord_top = 0
            crop_coord_left = 0

            value_dict["crop_coords_top"] = crop_coord_top
            value_dict["crop_coords_left"] = crop_coord_left

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict


def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    F = 8
    C = 4
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        seed = torch.seed()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1],
            )
            c = model.conditioner(batch)
            uc = None
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                )

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples


def v_spacer(height) -> None:
    for _ in range(height):
        st.write("\n")


if __name__ == "__main__":
    st.title("Turbo")

    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with head_cols[1]:
        v_spacer(2)
        if st.checkbox("Load Model"):
            mode = "txt2img"
        else:
            mode = "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # seed
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )
    sampler.n_sample_steps = n_steps
    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )

    cols = st.columns([1, 5, 1])
    if mode != "skip":
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=st.session_state.seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])
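What SubstepSampler.prepare_sampling_loop does with steps_subset, shown on a plain index list rather than the sigma tensor; a sketch that assumes the discretization built with num_steps=1000 exposes indices 0..1000.

steps_subset = [0, 100, 200, 300, 1000]
n_sample_steps = 2

# Keep the first n_sample_steps entries plus the final index, so a one- or
# few-step Turbo sample only visits a tiny subsequence of the full schedule.
indices = steps_subset[:n_sample_steps] + steps_subset[-1:]
print(indices)  # [0, 100, 1000]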
280
scripts/demo/video_sampling.py
Normal file
@@ -0,0 +1,280 @@
import os
import sys

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.demo.sv3d_helpers import *

SAVE_PATH = "outputs/demo/vid/"

VERSION2SPECS = {
    "svd": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_image_decoder": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_xt": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd_xt.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "svd_xt_image_decoder": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "sv3d_u": {
        "T": 21,
        "H": 576,
        "W": 576,
        "C": 4,
        "f": 8,
        "config": "configs/inference/sv3d_u.yaml",
        "ckpt": "checkpoints/sv3d_u.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 3,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 50,
            "decoding_t": 14,
        },
    },
    "sv3d_p": {
        "T": 21,
        "H": 576,
        "W": 576,
        "C": 4,
        "f": 8,
        "config": "configs/inference/sv3d_p.yaml",
        "ckpt": "checkpoints/sv3d_p.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 3,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 50,
            "decoding_t": 14,
        },
    },
}


if __name__ == "__main__":
    st.title("Stable Video Diffusion / SV3D")
    version = st.selectbox(
        "Model Version",
        [k for k in VERSION2SPECS.keys()],
        0,
    )
    version_dict = VERSION2SPECS[version]
    if st.checkbox("Load Model"):
        mode = "img2vid"
    else:
        mode = "skip"

    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=0, max_value=128
    )
    C = version_dict["C"]
    F = version_dict["f"]
    options = version_dict["options"]

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        ukeys = set(
            get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
        )

        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        if "fps" not in ukeys:
            value_dict["fps"] = 10

        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            if "sv3d" in version:
                cond_aug = 1e-5
            else:
                cond_aug = st.number_input(
                    "Conditioning augmentation:", value=0.02, min_value=0.0
                )
            value_dict["cond_frames_without_noise"] = img
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        if "sv3d_p" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            trajectory = st.selectbox(
                "Trajectory",
                ["same elevation", "dynamic"],
                0,
            )
            if trajectory == "same elevation":
                value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
                value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
            elif trajectory == "dynamic":
                azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
                value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
                value_dict["azimuths_rad"] = azim_rad
        elif "sv3d_u" in version:
            elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
            value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
            value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        if "sv3d" in version:
            plot_save_path = os.path.join(save_path, "plot_3D.png")
            plot_3D(
                azim=value_dict["azimuths_rad"],
                polar=value_dict["polars_rad"],
                save_path=plot_save_path,
                dynamic=("sv3d_p" in version),
            )
            st.image(
                plot_save_path,
                f"3D camera trajectory",
            )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                f"saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            if isinstance(out, (tuple, list)):
                samples, samples_z = out
            else:
                samples = out
                samples_z = None

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
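The "same elevation" SV3D trajectory used above, as a standalone sketch: a constant polar angle and T equally spaced azimuths covering one full turn.

import numpy as np

T = 21
elev_deg = 5
polars_rad = np.array([np.deg2rad(90 - elev_deg)] * T)
azimuths_rad = np.linspace(0, 2 * np.pi, T + 1)[1:]

assert polars_rad.shape == (T,) and azimuths_rad.shape == (T,)
print(np.rad2deg(azimuths_rad[:3]))  # ≈ [17.1, 34.3, 51.4] degrees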
132
scripts/sampling/configs/sv3d_p.yaml
Normal file
@@ -0,0 +1,132 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv3d_p.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 1280
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - input_key: cond_frames_without_noise
          is_trainable: False
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: polars_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuths_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
          params:
            max_scale: 2.5
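A minimal sketch of turning a config like the one above into a model, the way the demo and sampling scripts in this repository generally do it (OmegaConf plus sgm.util.instantiate_from_config); treat the exact loading behaviour as an assumption.

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/sv3d_p.yaml")
# ckpt_path above points at checkpoints/sv3d_p.safetensors, so that file must
# exist before instantiation loads the weights.
model = instantiate_from_config(config.model).eval()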
120
scripts/sampling/configs/sv3d_u.yaml
Normal file
@@ -0,0 +1,120 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv3d_u.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 256
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_type: vanilla-xformers
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [ 1, 2, 4, 4 ]
            num_res_blocks: 2
            attn_resolutions: [ ]
            dropout: 0.0

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
          params:
            max_scale: 2.5
203
scripts/sampling/configs/sv4d.yaml
Executable file
@@ -0,0 +1,203 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 7
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        use_motion_attention: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 2.5
            num_frames: ${N_FRAMES}
            additional_cond_keys: [ cond_view, cond_motion ]
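For the values declared at the top of this config, the frame bookkeeping factors as time steps times views. A tiny check, as a sketch that only assumes OmegaConf is installed:

from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/sampling/configs/sv4d.yaml")
# 5 time steps x 8 views = 40 latent frames, matching N_FRAMES above.
assert cfg.N_TIME * cfg.N_VIEW == cfg.N_FRAMES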
208
scripts/sampling/configs/sv4d2.yaml
Executable file
@@ -0,0 +1,208 @@
N_TIME: 12
N_VIEW: 4
N_FRAMES: 48

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d2.safetensors
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        separate_motion_merge_factor: True
        use_motion_attention: True
        use_3d_attention: True
        use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 1.5
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
208
scripts/sampling/configs/sv4d2_8views.yaml
Executable file
@@ -0,0 +1,208 @@
N_TIME: 5
N_VIEW: 8
N_FRAMES: 40

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sv4d2_8views.safetensors
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.SpatialUNetModelWithTime
      params:
        adm_in_channels: 1280
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        context_dim: 1024
        motion_context_dim: 4
        extra_ff_mix_layer: True
        in_channels: 8
        legacy: False
        model_channels: 320
        num_classes: sequential
        num_head_channels: 64
        num_res_blocks: 2
        out_channels: 4
        replicate_time_mix_bug: True
        spatial_transformer_attn_type: softmax-xformers
        time_block_merge_factor: 0.0
        time_block_merge_strategy: learned_with_images
        time_kernel_size: [3, 1, 1]
        time_mix_legacy: False
        transformer_depth: 1
        use_checkpoint: False
        use_linear_in_transformer: True
        use_spatial_context: True
        use_spatial_transformer: True
        separate_motion_merge_factor: True
        use_motion_attention: True
        use_3d_attention: False
        use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

        - input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          is_trainable: False
          params:
            n_cond_frames: ${N_TIME}
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: cond_frames
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          is_trainable: False
          params:
            is_ae: True
            n_cond_frames: ${N_FRAMES}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                embed_dim: 4
                lossconfig:
                  target: torch.nn.Identity
                monitor: val/rec_loss
            sigma_cond_config:
              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
              params:
                outdim: 256
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: polar_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: azimuth_rad
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 512

        - input_key: cond_view
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_VIEW}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

        - input_key: cond_motion
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            is_ae: True
            n_cond_frames: ${N_TIME}
            n_copies: 1
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_resolutions: []
                  attn_type: vanilla-xformers
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  double_z: True
                  dropout: 0.0
                  in_channels: 3
                  num_res_blocks: 2
                  out_ch: 3
                  resolution: 256
                  z_channels: 4
                lossconfig:
                  target: torch.nn.Identity
            sigma_sampler_config:
              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.Decoder
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 2.0
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
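The `${N_TIME}` / `${N_VIEW}` / `${N_FRAMES}` placeholders in this config are OmegaConf interpolations against the file's top-level keys, so the guider receives 40 frames laid out as 5 time steps times 8 views. A minimal sketch of how they resolve, assuming the repo and `omegaconf` are installed:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/sampling/configs/sv4d2_8views.yaml")
guider = cfg.model.params.sampler_config.params.guider_config.params

# N_FRAMES is the product of frames-per-sample and views-per-sample: 5 * 8 = 40
assert cfg.N_FRAMES == cfg.N_TIME * cfg.N_VIEW == 40
print(guider.num_frames, guider.num_views)  # -> 40 8
```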
146
scripts/sampling/configs/svd.yaml
Normal file
@@ -0,0 +1,146 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/svd.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 2.5
            min_scale: 1.0
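Configs in this family are consumed via `OmegaConf.load` plus `sgm.util.instantiate_from_config`, which is exactly what `load_model` in `scripts/sampling/simple_video_sample.py` further down in this diff does. A minimal sketch, assuming the package is installed, PyTorch 2.x, and `checkpoints/svd.safetensors` is present:

```python
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/svd.yaml")
# num_steps and the guider's num_frames are left unset in the file above;
# callers fill them in, mirroring load_model() below.
config.model.params.sampler_config.params.num_steps = 25
config.model.params.sampler_config.params.guider_config.params.num_frames = 14

with torch.device("cuda"):
    model = instantiate_from_config(config.model).to("cuda").eval()
```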
129
scripts/sampling/configs/svd_image_decoder.yaml
Normal file
@@ -0,0 +1,129 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/svd_image_decoder.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 2.5
            min_scale: 1.0
146
scripts/sampling/configs/svd_xt.yaml
Normal file
@@ -0,0 +1,146 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/svd_xt.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 3.0
            min_scale: 1.5
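The `LinearPredictionGuider` used here varies the classifier-free guidance scale across the clip rather than applying one fixed value. A rough sketch of the intended schedule (an assumption about the guider's behavior, not its actual code), with a linear ramp from `min_scale` at the first frame to `max_scale` at the last of the 25 svd_xt frames:

```python
import torch

num_frames, min_scale, max_scale = 25, 1.5, 3.0
scales = torch.linspace(min_scale, max_scale, num_frames)  # one guidance scale per frame
print(scales[0].item(), scales[-1].item())  # 1.5 3.0
```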
146
scripts/sampling/configs/svd_xt_1_1.yaml
Normal file
@@ -0,0 +1,146 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/svd_xt_1_1.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 3.0
            min_scale: 1.5
129
scripts/sampling/configs/svd_xt_image_decoder.yaml
Normal file
@@ -0,0 +1,129 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/svd_xt_image_decoder.safetensors

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            max_scale: 3.0
            min_scale: 1.5
350
scripts/sampling/simple_video_sample.py
Normal file
@@ -0,0 +1,350 @@
import math
import os
import sys
from glob import glob
from pathlib import Path
from typing import List, Optional

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import cv2
import imageio
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
from rembg import remove
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from torchvision.transforms import ToTensor


def sample(
    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
    num_frames: Optional[int] = None,  # 21 for SV3D
    num_steps: Optional[int] = None,
    version: str = "svd",
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 0.02,
    seed: int = 23,
    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    output_folder: Optional[str] = None,
    elevations_deg: Optional[float | List[float]] = 10.0,  # For SV3D
    azimuths_deg: Optional[List[float]] = None,  # For SV3D
    image_frame_ratio: Optional[float] = None,
    verbose: Optional[bool] = False,
):
    """
    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
    """

    if version == "svd":
        num_frames = default(num_frames, 14)
        num_steps = default(num_steps, 25)
        output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
        model_config = "scripts/sampling/configs/svd.yaml"
    elif version == "svd_xt":
        num_frames = default(num_frames, 25)
        num_steps = default(num_steps, 30)
        output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
        model_config = "scripts/sampling/configs/svd_xt.yaml"
    elif version == "svd_image_decoder":
        num_frames = default(num_frames, 14)
        num_steps = default(num_steps, 25)
        output_folder = default(
            output_folder, "outputs/simple_video_sample/svd_image_decoder/"
        )
        model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
    elif version == "svd_xt_image_decoder":
        num_frames = default(num_frames, 25)
        num_steps = default(num_steps, 30)
        output_folder = default(
            output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
        )
        model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
    elif version == "sv3d_u":
        num_frames = 21
        num_steps = default(num_steps, 50)
        output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/")
        model_config = "scripts/sampling/configs/sv3d_u.yaml"
        cond_aug = 1e-5
    elif version == "sv3d_p":
        num_frames = 21
        num_steps = default(num_steps, 50)
        output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/")
        model_config = "scripts/sampling/configs/sv3d_p.yaml"
        cond_aug = 1e-5
        if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
            elevations_deg = [elevations_deg] * num_frames
        assert (
            len(elevations_deg) == num_frames
        ), f"Please provide 1 value, or a list of {num_frames} values for elevations_deg! Given {len(elevations_deg)}"
        polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
        if azimuths_deg is None:
            azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
        assert (
            len(azimuths_deg) == num_frames
        ), f"Please provide a list of {num_frames} values for azimuths_deg! Given {len(azimuths_deg)}"
        azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
        azimuths_rad[:-1].sort()
    else:
        raise ValueError(f"Version {version} does not exist.")

    model, filter = load_model(
        model_config,
        device,
        num_frames,
        num_steps,
        verbose,
    )
    torch.manual_seed(seed)

    path = Path(input_path)
    all_img_paths = []
    if path.is_file():
        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
            all_img_paths = [input_path]
        else:
            raise ValueError("Path is not valid image file.")
    elif path.is_dir():
        all_img_paths = sorted(
            [
                f
                for f in path.iterdir()
                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
            ]
        )
        if len(all_img_paths) == 0:
            raise ValueError("Folder does not contain any images.")
    else:
        raise ValueError

    for input_img_path in all_img_paths:
        if "sv3d" in version:
            image = Image.open(input_img_path)
            if image.mode == "RGBA":
                pass
            else:
                # remove bg
                image.thumbnail([768, 768], Image.Resampling.LANCZOS)
                image = remove(image.convert("RGBA"), alpha_matting=True)

            # resize object in frame
            image_arr = np.array(image)
            in_w, in_h = image_arr.shape[:2]
            ret, mask = cv2.threshold(
                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
            )
            x, y, w, h = cv2.boundingRect(mask)
            max_size = max(w, h)
            side_len = (
                int(max_size / image_frame_ratio)
                if image_frame_ratio is not None
                else in_w
            )
            padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
            center = side_len // 2
            padded_image[
                center - h // 2 : center - h // 2 + h,
                center - w // 2 : center - w // 2 + w,
            ] = image_arr[y : y + h, x : x + w]
            # resize frame to 576x576
            rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
            # white bg
            rgba_arr = np.array(rgba) / 255.0
            rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
            input_image = Image.fromarray((rgb * 255).astype(np.uint8))

        else:
            with Image.open(input_img_path) as image:
                if image.mode == "RGBA":
                    image = image.convert("RGB")
                w, h = image.size

                if h % 64 != 0 or w % 64 != 0:
                    width, height = map(lambda x: x - x % 64, (w, h))
                    input_image = input_image.resize((width, height))
                    print(
                        f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
                    )
                input_image = np.array(image)

        image = ToTensor()(input_image)
        image = image * 2.0 - 1.0

        image = image.unsqueeze(0).to(device)
        H, W = image.shape[2:]
        assert image.shape[1] == 3
        F = 8
        C = 4
        shape = (num_frames, C, H // F, W // F)
        if (H, W) != (576, 1024) and "sv3d" not in version:
            print(
                "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
            )
        if (H, W) != (576, 576) and "sv3d" in version:
            print(
                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
            )
        if motion_bucket_id > 255:
            print(
                "WARNING: High motion bucket! This may lead to suboptimal performance."
            )

        if fps_id < 5:
            print("WARNING: Small fps value! This may lead to suboptimal performance.")

        if fps_id > 30:
            print("WARNING: Large fps value! This may lead to suboptimal performance.")

        value_dict = {}
        value_dict["cond_frames_without_noise"] = image
        value_dict["motion_bucket_id"] = motion_bucket_id
        value_dict["fps_id"] = fps_id
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
        if "sv3d_p" in version:
            value_dict["polars_rad"] = polars_rad
            value_dict["azimuths_rad"] = azimuths_rad

        with torch.no_grad():
            with torch.autocast(device):
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [1, num_frames],
                    T=num_frames,
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=[
                        "cond_frames",
                        "cond_frames_without_noise",
                    ],
                )

                for k in ["crossattn", "concat"]:
                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

                randn = torch.randn(shape, device=device)

                additional_model_inputs = {}
                additional_model_inputs["image_only_indicator"] = torch.zeros(
                    2, num_frames
                ).to(device)
                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                model.en_and_decode_n_samples_a_time = decoding_t
                samples_x = model.decode_first_stage(samples_z)
                if "sv3d" in version:
                    samples_x[-1:] = value_dict["cond_frames_without_noise"]
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                os.makedirs(output_folder, exist_ok=True)
                base_count = len(glob(os.path.join(output_folder, "*.mp4")))

                imageio.imwrite(
                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
                )

                samples = embed_watermark(samples)
                samples = filter(samples)
                vid = (
                    (rearrange(samples, "t c h w -> t h w c") * 255)
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                )
                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
                imageio.mimwrite(video_path, vid)


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames" or key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
        elif key == "polars_rad" or key == "azimuths_rad":
            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def load_model(
    config: str,
    device: str,
    num_frames: int,
    num_steps: int,
    verbose: bool = False,
):
    config = OmegaConf.load(config)
    if device == "cuda":
        config.model.params.conditioner_config.params.emb_models[
            0
        ].params.open_clip_embedding_config.params.init_device = device

    config.model.params.sampler_config.params.verbose = verbose
    config.model.params.sampler_config.params.num_steps = num_steps
    config.model.params.sampler_config.params.guider_config.params.num_frames = (
        num_frames
    )
    if device == "cuda":
        with torch.device(device):
            model = instantiate_from_config(config.model).to(device).eval()
    else:
        model = instantiate_from_config(config.model).to(device).eval()

    filter = DeepFloydDataFiltering(verbose=False, device=device)
    return model, filter


if __name__ == "__main__":
    Fire(sample)
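Because the script ends in `Fire(sample)`, every keyword argument of `sample` doubles as a CLI flag. An equivalent direct call, using placeholder paths and only arguments that appear in the signature above, might look like:

```python
from scripts.sampling.simple_video_sample import sample

sample(
    input_path="assets/test_image.png",  # placeholder conditioning image
    version="svd_xt",                    # selects scripts/sampling/configs/svd_xt.yaml
    decoding_t=5,                        # decode fewer frames at a time if VRAM is tight
    seed=23,
)
```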
259
scripts/sampling/simple_video_sample_4d.py
Executable file
@@ -0,0 +1,259 @@
import os
import sys
from glob import glob
from typing import List, Optional, Union

from tqdm import tqdm

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire

from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
from scripts.demo.sv4d_helpers import (
    decode_latents,
    load_model,
    initial_model_load,
    read_video,
    run_img2vid,
    prepare_sampling,
    prepare_inputs,
    do_sample_per_step,
    sample_sv3d,
    save_video,
    preprocess_video,
)


def sample(
    input_path: str = "assets/sv4d_videos/test_video1.mp4",  # Can either be image file or folder with image files
    output_folder: Optional[str] = "outputs/sv4d",
    num_steps: Optional[int] = 20,
    sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
    img_size: int = 576,  # image resolution
    fps_id: int = 6,
    motion_bucket_id: int = 127,
    cond_aug: float = 1e-5,
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[Union[float, List[float]]] = 10.0,
    azimuths_deg: Optional[List[float]] = None,
    image_frame_ratio: Optional[float] = 0.917,
    verbose: Optional[bool] = False,
    remove_bg: bool = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
    """
    # Set model config
    T = 5  # number of frames per sample
    V = 8  # number of views per sample
    F = 8  # vae factor to downsize image->latent
    C = 4
    H, W = img_size, img_size
    n_frames = 21  # number of input and output video frames
    n_views = V + 1  # number of output video views (1 input view + 8 novel views)
    n_views_sv3d = 21
    subsampled_views = np.array(
        [0, 2, 5, 7, 9, 12, 14, 16, 19]
    )  # subsample (V+1=)9 (uniform) views from 21 SV3D views

    model_config = "scripts/sampling/configs/sv4d.yaml"
    version_dict = {
        "T": T * V,
        "H": H,
        "W": W,
        "C": C,
        "f": F,
        "options": {
            "discretization": 1,
            "cfg": 2.0,
            "num_views": V,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 5,
            "num_steps": num_steps,
            "force_uc_zero_embeddings": [
                "cond_frames",
                "cond_frames_without_noise",
                "cond_view",
                "cond_motion",
            ],
            "additional_guider_kwargs": {
                "additional_cond_keys": ["cond_view", "cond_motion"]
            },
        },
    }

    torch.manual_seed(seed)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
    processed_input_path = preprocess_video(
        input_path,
        remove_bg=remove_bg,
        n_frames=n_frames,
        W=W,
        H=H,
        output_folder=output_folder,
        image_frame_ratio=image_frame_ratio,
        base_count=base_count,
    )
    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views_sv3d
    assert (
        len(elevations_deg) == n_views_sv3d
    ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views_sv3d
    ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Sample multi-view images of the first frame using SV3D i.e. images at time 0
    images_t0 = sample_sv3d(
        images_v0[0],
        n_views_sv3d,
        num_steps,
        sv3d_version,
        fps_id,
        motion_bucket_id,
        cond_aug,
        decoding_t,
        device,
        polars_rad,
        azimuths_rad,
        verbose,
    )
    images_t0 = torch.roll(images_t0, 1, 0)  # move conditioning image to first frame

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    save_video(
        os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
        img_matrix[0],
    )
    # save_video(
    #     os.path.join(output_folder, f"{base_count:06d}_v000.mp4"),
    #     [img_matrix[t][0] for t in range(n_frames)],
    # )

    # Load SV4D model
    model, filter = load_model(
        model_config,
        device,
        version_dict["T"],
        num_steps,
        verbose,
    )
    model = initial_model_load(model)
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t
    model.en_and_decode_n_samples_a_time = decoding_t

    # Interleaved sampling for anchor frames
    t0, v0 = 0, 0
    frame_indices = np.arange(T - 1, n_frames, T - 1)  # [4, 8, 12, 16, 20]
    view_indices = np.arange(V) + 1
    print(f"Sampling anchor frames {frame_indices}")
    image = img_matrix[t0][v0]
    cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
    cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
    samples = run_img2vid(
        version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
    )
    samples = samples.view(T, V, 3, H, W)
    for i, t in enumerate(frame_indices):
        for j, v in enumerate(view_indices):
            if img_matrix[t][v] is None:
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # Dense sampling for the rest
    print(f"Sampling dense frames:")
    for t0 in tqdm(np.arange(0, n_frames - 1, T - 1)):  # [0, 4, 8, 12, 16]
        frame_indices = t0 + np.arange(T)
        print(f"Sampling dense frames {frame_indices}")
        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")

        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)

        # alternate between forward and backward conditioning
        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
            frame_indices,
            img_matrix,
            v0,
            view_indices,
            model,
            version_dict,
            seed,
            polars,
            azims
        )

        for step in tqdm(range(num_steps)):
            if step % 2 == 1:
                c, uc, additional_model_inputs, sampler = forward_inputs
                frame_indices = forward_frame_indices
            else:
                c, uc, additional_model_inputs, sampler = backward_inputs
                frame_indices = backward_frame_indices
            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)

            samples = do_sample_per_step(
                model,
                sampler,
                noisy_latents,
                c,
                uc,
                step,
                additional_model_inputs,
            )
            samples = samples.view(T, V, C, H // F, W // F)
            for i, t in enumerate(frame_indices):
                for j, v in enumerate(view_indices):
                    latent_matrix[t, v] = samples[i, j]

        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

    # Save output videos
    for v in view_indices:
        vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
        print(f"Saving {vid_file}")
        save_video(vid_file, [img_matrix[t][v] for t in range(n_frames)])

    # Save diagonal video
    diag_frames = [
        img_matrix[t][(t // (n_frames // n_views)) % n_views] for t in range(n_frames)
    ]
    vid_file = os.path.join(output_folder, f"{base_count:06d}_diag.mp4")
    print(f"Saving {vid_file}")
    save_video(vid_file, diag_frames)


if __name__ == "__main__":
    Fire(sample)
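The anchor/dense schedule above interleaves two passes over the 21 input frames: one pass hits every (T-1)-th frame as an anchor, then dense passes fill each 5-frame window between anchors. A small sketch (same `T = 5`, `n_frames = 21`) of the indices each pass visits:

```python
import numpy as np

T, n_frames = 5, 21
anchor_frames = np.arange(T - 1, n_frames, T - 1)   # [ 4  8 12 16 20]
dense_starts = np.arange(0, n_frames - 1, T - 1)    # [ 0  4  8 12 16]
dense_windows = [t0 + np.arange(T) for t0 in dense_starts]
print(anchor_frames, dense_starts, dense_windows[0])  # consecutive windows overlap at the anchors
```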
235
scripts/sampling/simple_video_sample_4d2.py
Executable file
@@ -0,0 +1,235 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from glob import glob
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from fire import Fire
|
||||||
|
from scripts.demo.sv4d_helpers import (
|
||||||
|
load_model,
|
||||||
|
preprocess_video,
|
||||||
|
read_video,
|
||||||
|
run_img2vid,
|
||||||
|
save_video,
|
||||||
|
)
|
||||||
|
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder
|
||||||
|
|
||||||
|
sv4d2_configs = {
|
||||||
|
"sv4d2": {
|
||||||
|
"T": 12, # number of frames per sample
|
||||||
|
"V": 4, # number of views per sample
|
||||||
|
"model_config": "scripts/sampling/configs/sv4d2.yaml",
|
||||||
|
"version_dict": {
|
||||||
|
"T": 12 * 4,
|
||||||
|
"options": {
|
||||||
|
"discretization": 1,
|
||||||
|
"cfg": 2.0,
|
||||||
|
"min_cfg": 2.0,
|
||||||
|
"num_views": 4,
|
||||||
|
"sigma_min": 0.002,
|
||||||
|
"sigma_max": 700.0,
|
||||||
|
"rho": 7.0,
|
||||||
|
"guider": 2,
|
||||||
|
"force_uc_zero_embeddings": [
|
||||||
|
"cond_frames",
|
||||||
|
"cond_frames_without_noise",
|
||||||
|
"cond_view",
|
||||||
|
"cond_motion",
|
||||||
|
],
|
||||||
|
"additional_guider_kwargs": {
|
||||||
|
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"sv4d2_8views": {
|
||||||
|
"T": 5, # number of frames per sample
|
||||||
|
"V": 8, # number of views per sample
|
||||||
|
"model_config": "scripts/sampling/configs/sv4d2_8views.yaml",
|
||||||
|
"version_dict": {
|
||||||
|
"T": 5 * 8,
|
||||||
|
"options": {
|
||||||
|
"discretization": 1,
|
||||||
|
"cfg": 2.5,
|
||||||
|
"min_cfg": 1.5,
|
||||||
|
"num_views": 8,
|
||||||
|
"sigma_min": 0.002,
|
||||||
|
"sigma_max": 700.0,
|
||||||
|
"rho": 7.0,
|
||||||
|
"guider": 5,
|
||||||
|
"force_uc_zero_embeddings": [
|
||||||
|
"cond_frames",
|
||||||
|
"cond_frames_without_noise",
|
||||||
|
"cond_view",
|
||||||
|
"cond_motion",
|
||||||
|
],
|
||||||
|
"additional_guider_kwargs": {
|
||||||
|
"additional_cond_keys": ["cond_view", "cond_motion"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
input_path: str = "assets/sv4d_videos/camel.gif", # Can either be image file or folder with image files
|
||||||
|
model_path: Optional[str] = "checkpoints/sv4d2.safetensors",
|
||||||
|
output_folder: Optional[str] = "outputs",
|
||||||
|
num_steps: Optional[int] = 50,
|
||||||
|
img_size: int = 576, # image resolution
|
||||||
|
n_frames: int = 21, # number of input and output video frames
|
||||||
|
seed: int = 23,
|
||||||
|
encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
|
||||||
|
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
|
||||||
|
device: str = "cuda",
|
||||||
|
elevations_deg: Optional[List[float]] = 0.0,
|
||||||
|
azimuths_deg: Optional[List[float]] = None,
|
||||||
|
image_frame_ratio: Optional[float] = 0.9,
|
||||||
|
verbose: Optional[bool] = False,
|
||||||
|
remove_bg: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
|
||||||
|
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
|
||||||
|
"""
|
||||||
|
# Set model config
|
||||||
|
assert os.path.basename(model_path) in [
|
||||||
|
"sv4d2.safetensors",
|
||||||
|
"sv4d2_8views.safetensors",
|
||||||
|
]
|
||||||
|
sv4d2_model = os.path.splitext(os.path.basename(model_path))[0]
|
||||||
|
config = sv4d2_configs[sv4d2_model]
|
||||||
|
print(sv4d2_model, config)
|
||||||
|
T = config["T"]
|
||||||
|
V = config["V"]
|
||||||
|
model_config = config["model_config"]
|
||||||
|
version_dict = config["version_dict"]
|
||||||
|
F = 8 # vae factor to downsize image->latent
|
||||||
|
C = 4
|
||||||
|
H, W = img_size, img_size
|
||||||
|
n_views = V + 1  # number of output video views (1 input view + V novel views)
|
||||||
|
subsampled_views = np.arange(n_views)
|
||||||
|
version_dict["H"] = H
|
||||||
|
version_dict["W"] = W
|
||||||
|
version_dict["C"] = C
|
||||||
|
version_dict["f"] = F
|
||||||
|
version_dict["options"]["num_steps"] = num_steps
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
output_folder = os.path.join(output_folder, sv4d2_model)
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
|
||||||
|
# Read input video frames i.e. images at view 0
|
||||||
|
print(f"Reading {input_path}")
|
||||||
|
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // n_views
|
||||||
|
processed_input_path = preprocess_video(
|
||||||
|
input_path,
|
||||||
|
remove_bg=remove_bg,
|
||||||
|
n_frames=n_frames,
|
||||||
|
W=W,
|
||||||
|
H=H,
|
||||||
|
output_folder=output_folder,
|
||||||
|
image_frame_ratio=image_frame_ratio,
|
||||||
|
base_count=base_count,
|
||||||
|
)
|
||||||
|
images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
|
||||||
|
images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)
|
||||||
|
|
||||||
|
# Get camera viewpoints
|
||||||
|
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
|
||||||
|
elevations_deg = [elevations_deg] * n_views
|
||||||
|
assert (
|
||||||
|
len(elevations_deg) == n_views
|
||||||
|
), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
|
||||||
|
if azimuths_deg is None:
|
||||||
|
# azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
|
||||||
|
azimuths_deg = (
|
||||||
|
np.array([0, 60, 120, 180, 240])
|
||||||
|
if sv4d2_model == "sv4d2"
|
||||||
|
else np.array([0, 30, 75, 120, 165, 210, 255, 300, 330])
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(azimuths_deg) == n_views
|
||||||
|
), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
|
||||||
|
polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
|
||||||
|
azimuths_rad = np.array(
|
||||||
|
[np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize image matrix
|
||||||
|
img_matrix = [[None] * n_views for _ in range(n_frames)]
|
||||||
|
for i, v in enumerate(subsampled_views):
|
||||||
|
img_matrix[0][i] = images_t0[v].unsqueeze(0)
|
||||||
|
for t in range(n_frames):
|
||||||
|
img_matrix[t][0] = images_v0[t]
|
||||||
|
|
||||||
|
# Load SV4D++ model
|
||||||
|
model, _ = load_model(
|
||||||
|
model_config,
|
||||||
|
device,
|
||||||
|
version_dict["T"],
|
||||||
|
num_steps,
|
||||||
|
verbose,
|
||||||
|
model_path,
|
||||||
|
)
|
||||||
|
model.en_and_decode_n_samples_a_time = decoding_t
|
||||||
|
for emb in model.conditioner.embedders:
|
||||||
|
if isinstance(emb, VideoPredictionEmbedderWithEncoder):
|
||||||
|
emb.en_and_decode_n_samples_a_time = encoding_t
|
||||||
|
|
||||||
|
# Sampling novel-view videos
|
||||||
|
v0 = 0
|
||||||
|
view_indices = np.arange(V) + 1
|
||||||
|
t0_list = (
|
||||||
|
range(0, n_frames, T-1)
|
||||||
|
if sv4d2_model == "sv4d2"
|
||||||
|
else range(0, n_frames - T + 1, T - 1)
|
||||||
|
)
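# Sliding-window sampling: each chunk spans T frames and advances by T - 1, so
# consecutive chunks overlap by one frame; the clamp below keeps the last chunk
# inside the video.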
|
||||||
|
for t0 in tqdm(t0_list):
|
||||||
|
if t0 + T > n_frames:
|
||||||
|
t0 = n_frames - T
|
||||||
|
frame_indices = t0 + np.arange(T)
|
||||||
|
print(f"Sampling frames {frame_indices}")
|
||||||
|
image = img_matrix[t0][v0]
|
||||||
|
cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
|
||||||
|
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
|
||||||
|
polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
|
||||||
|
polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
|
||||||
|
azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
|
||||||
|
cond_mv = t0 > 0
|
||||||
|
samples = run_img2vid(
|
||||||
|
version_dict,
|
||||||
|
model,
|
||||||
|
image,
|
||||||
|
seed,
|
||||||
|
polars,
|
||||||
|
azims,
|
||||||
|
cond_motion,
|
||||||
|
cond_view,
|
||||||
|
decoding_t,
|
||||||
|
cond_mv=cond_mv,
|
||||||
|
)
|
||||||
|
samples = samples.view(T, V, 3, H, W)
|
||||||
|
|
||||||
|
for i, t in enumerate(frame_indices):
|
||||||
|
for j, v in enumerate(view_indices):
|
||||||
|
img_matrix[t][v] = samples[i, j][None] * 2 - 1
|
||||||
|
|
||||||
|
# Save output videos
|
||||||
|
for v in view_indices:
|
||||||
|
vid_file = os.path.join(output_folder, f"{base_count:06d}_v{v:03d}.mp4")
|
||||||
|
print(f"Saving {vid_file}")
|
||||||
|
save_video(
|
||||||
|
vid_file,
|
||||||
|
[img_matrix[t][v] for t in range(n_frames) if img_matrix[t][v] is not None],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Fire(sample)
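
For reference, the entry point can be driven either through Fire on the command line or called directly from Python; the import path below is an assumption (the exact script filename is not part of this hunk), and the argument values are placeholders:

# Hypothetical import path; adjust to the actual location of this script.
from scripts.sampling.simple_video_sample_4d2 import sample

sample(
    input_path="assets/sv4d_videos/camel.gif",
    model_path="checkpoints/sv4d2_8views.safetensors",
    encoding_t=4,  # lower encoding_t/decoding_t if you run out of VRAM
    decoding_t=2,
)
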
|
||||||
319
scripts/tests/attention.py
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
from torch.backends.cuda import SDPBackend
|
||||||
|
|
||||||
|
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_attn():
|
||||||
|
# Let's define a helpful benchmarking function:
|
||||||
|
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||||
|
t0 = benchmark.Timer(
|
||||||
|
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||||
|
)
|
||||||
|
return t0.blocked_autorange().mean * 1e6
|
||||||
|
|
||||||
|
# Let's define the hyper-parameters of our input
|
||||||
|
batch_size = 32
|
||||||
|
max_sequence_len = 1024
|
||||||
|
num_heads = 32
|
||||||
|
embed_dimension = 32
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
|
query = torch.rand(
|
||||||
|
batch_size,
|
||||||
|
num_heads,
|
||||||
|
max_sequence_len,
|
||||||
|
embed_dimension,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
key = torch.rand(
|
||||||
|
batch_size,
|
||||||
|
num_heads,
|
||||||
|
max_sequence_len,
|
||||||
|
embed_dimension,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
value = torch.rand(
|
||||||
|
batch_size,
|
||||||
|
num_heads,
|
||||||
|
max_sequence_len,
|
||||||
|
embed_dimension,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"q/k/v shape:", query.shape, key.shape, value.shape)
|
||||||
|
|
||||||
|
# Let's explore the speed of each of the 3 implementations
|
||||||
|
from torch.backends.cuda import SDPBackend, sdp_kernel
|
||||||
|
|
||||||
|
# Helpful arguments mapper
|
||||||
|
backend_map = {
|
||||||
|
SDPBackend.MATH: {
|
||||||
|
"enable_math": True,
|
||||||
|
"enable_flash": False,
|
||||||
|
"enable_mem_efficient": False,
|
||||||
|
},
|
||||||
|
SDPBackend.FLASH_ATTENTION: {
|
||||||
|
"enable_math": False,
|
||||||
|
"enable_flash": True,
|
||||||
|
"enable_mem_efficient": False,
|
||||||
|
},
|
||||||
|
SDPBackend.EFFICIENT_ATTENTION: {
|
||||||
|
"enable_math": False,
|
||||||
|
"enable_flash": False,
|
||||||
|
"enable_mem_efficient": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
|
|
||||||
|
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||||
|
)
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("Default detailed stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
o = F.scaled_dot_product_attention(query, key, value)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||||
|
)
|
||||||
|
with sdp_kernel(**backend_map[SDPBackend.MATH]):
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("Math implmentation stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
o = F.scaled_dot_product_attention(query, key, value)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
|
||||||
|
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
|
||||||
|
try:
|
||||||
|
print(
|
||||||
|
f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
print("FlashAttention is not supported. See warnings for reasons.")
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("FlashAttention stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
o = F.scaled_dot_product_attention(query, key, value)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
|
||||||
|
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
|
||||||
|
try:
|
||||||
|
print(
|
||||||
|
f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
print("EfficientAttention is not supported. See warnings for reasons.")
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("EfficientAttention stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
o = F.scaled_dot_product_attention(query, key, value)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
|
||||||
|
|
||||||
|
def run_model(model, x, context):
|
||||||
|
return model(x, context)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_transformer_blocks():
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
|
||||||
|
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
|
||||||
|
t0 = benchmark.Timer(
|
||||||
|
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
|
||||||
|
)
|
||||||
|
return t0.blocked_autorange().mean * 1e6
|
||||||
|
|
||||||
|
checkpoint = True
|
||||||
|
compile = False
|
||||||
|
|
||||||
|
batch_size = 32
|
||||||
|
h, w = 64, 64
|
||||||
|
context_len = 77
|
||||||
|
embed_dimension = 1024
|
||||||
|
context_dim = 1024
|
||||||
|
d_head = 64
|
||||||
|
|
||||||
|
transformer_depth = 4
|
||||||
|
|
||||||
|
n_heads = embed_dimension // d_head
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
|
model_native = SpatialTransformer(
|
||||||
|
embed_dimension,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_linear=True,
|
||||||
|
use_checkpoint=checkpoint,
|
||||||
|
attn_type="softmax",
|
||||||
|
depth=transformer_depth,
|
||||||
|
sdp_backend=SDPBackend.FLASH_ATTENTION,
|
||||||
|
).to(device)
|
||||||
|
model_efficient_attn = SpatialTransformer(
|
||||||
|
embed_dimension,
|
||||||
|
n_heads,
|
||||||
|
d_head,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_linear=True,
|
||||||
|
depth=transformer_depth,
|
||||||
|
use_checkpoint=checkpoint,
|
||||||
|
attn_type="softmax-xformers",
|
||||||
|
).to(device)
|
||||||
|
if not checkpoint and compile:
|
||||||
|
print("compiling models")
|
||||||
|
model_native = torch.compile(model_native)
|
||||||
|
model_efficient_attn = torch.compile(model_efficient_attn)
|
||||||
|
|
||||||
|
x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
|
||||||
|
c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
|
|
||||||
|
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
print(
|
||||||
|
f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(75 * "+")
|
||||||
|
print("NATIVE")
|
||||||
|
print(75 * "+")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("NativeAttention stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
model_native(x, c)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
|
||||||
|
|
||||||
|
print(75 * "+")
|
||||||
|
print("Xformers")
|
||||||
|
print(75 * "+")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
with profile(
|
||||||
|
activities=activities, record_shapes=False, profile_memory=True
|
||||||
|
) as prof:
|
||||||
|
with record_function("xformers stats"):
|
||||||
|
for _ in range(25):
|
||||||
|
model_efficient_attn(x, c)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
|
print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
|
||||||
|
|
||||||
|
|
||||||
|
def test01():
|
||||||
|
# conv1x1 vs linear
|
||||||
|
from sgm.util import count_params
|
||||||
|
|
||||||
|
conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
|
||||||
|
print(count_params(conv))
|
||||||
|
linear = torch.nn.Linear(3, 32).cuda()
|
||||||
|
print(count_params(linear))
|
||||||
|
|
||||||
|
print(conv.weight.shape)
|
||||||
|
|
||||||
|
# use same initialization
|
||||||
|
linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
|
||||||
|
linear.bias = torch.nn.Parameter(conv.bias)
|
||||||
|
|
||||||
|
print(linear.weight.shape)
|
||||||
|
|
||||||
|
x = torch.randn(11, 3, 64, 64).cuda()
|
||||||
|
|
||||||
|
xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||||
|
print(xr.shape)
|
||||||
|
out_linear = linear(xr)
|
||||||
|
print(out_linear.mean(), out_linear.shape)
|
||||||
|
|
||||||
|
out_conv = conv(x)
|
||||||
|
print(out_conv.mean(), out_conv.shape)
|
||||||
|
print("done with test01.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test02():
|
||||||
|
# try cosine flash attention
|
||||||
|
import time
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
print("testing cosine flash attention...")
|
||||||
|
DIM = 1024
|
||||||
|
SEQLEN = 4096
|
||||||
|
BS = 16
|
||||||
|
|
||||||
|
print(" softmax (vanilla) first...")
|
||||||
|
model = BasicTransformerBlock(
|
||||||
|
dim=DIM,
|
||||||
|
n_heads=16,
|
||||||
|
d_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
context_dim=None,
|
||||||
|
attn_mode="softmax",
|
||||||
|
).cuda()
|
||||||
|
try:
|
||||||
|
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||||
|
tic = time.time()
|
||||||
|
y = model(x)
|
||||||
|
toc = time.time()
|
||||||
|
print(y.shape, toc - tic)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# likely an out-of-memory error at this sequence length
|
||||||
|
print(str(e))
|
||||||
|
|
||||||
|
print("\n now flash-cosine...")
|
||||||
|
model = BasicTransformerBlock(
|
||||||
|
dim=DIM,
|
||||||
|
n_heads=16,
|
||||||
|
d_head=64,
|
||||||
|
dropout=0.0,
|
||||||
|
context_dim=None,
|
||||||
|
attn_mode="flash-cosine",
|
||||||
|
).cuda()
|
||||||
|
x = torch.randn(BS, SEQLEN, DIM).cuda()
|
||||||
|
tic = time.time()
|
||||||
|
y = model(x)
|
||||||
|
toc = time.time()
|
||||||
|
print(y.shape, toc - tic)
|
||||||
|
print("done with test02.\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# test01()
|
||||||
|
# test02()
|
||||||
|
# test03()
|
||||||
|
|
||||||
|
# benchmark_attn()
|
||||||
|
benchmark_transformer_blocks()
|
||||||
|
|
||||||
|
print("done.")
|
||||||
0
scripts/util/__init__.py
Normal file
0
scripts/util/detection/__init__.py
Normal file
@@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
|
import clip
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import clip
|
|
||||||
|
|
||||||
RESOURCES_ROOT = "scripts/util/detection/"
|
RESOURCES_ROOT = "scripts/util/detection/"
|
||||||
|
|
||||||
@@ -36,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
class DeepFloydDataFiltering(object):
|
class DeepFloydDataFiltering(object):
|
||||||
def __init__(self, verbose: bool = False):
|
def __init__(
|
||||||
|
self, verbose: bool = False, device: torch.device = torch.device("cpu")
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
|
self._device = None
|
||||||
|
self.clip_model, _ = clip.load("ViT-L/14", device=device)
|
||||||
self.clip_model.eval()
|
self.clip_model.eval()
|
||||||
|
|
||||||
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
||||||
@@ -53,7 +57,9 @@ class DeepFloydDataFiltering(object):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
||||||
imgs = clip_process_images(images)
|
imgs = clip_process_images(images)
|
||||||
image_features = self.clip_model.encode_image(imgs.to("cpu"))
|
if self._device is None:
|
||||||
|
self._device = next(p for p in self.clip_model.parameters()).device
|
||||||
|
image_features = self.clip_model.encode_image(imgs.to(self._device))
|
||||||
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
||||||
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
||||||
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
|
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
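
With the device argument added above, the CLIP backbone no longer has to stay on the CPU; a minimal sketch of GPU usage (the images tensor is assumed to be a (B, 3, H, W) sampler output with values in [0, 1]):

nsfw_filter = DeepFloydDataFiltering(verbose=True, device=torch.device("cuda"))
filtered = nsfw_filter(images)  # images: (B, 3, H, W) in [0, 1]
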
|
||||||
|
|||||||
13
setup.py
@@ -1,13 +0,0 @@
|
|||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name="sgm",
|
|
||||||
version="0.0.1",
|
|
||||||
packages=find_packages(),
|
|
||||||
python_requires=">=3.8",
|
|
||||||
py_modules=["sgm"],
|
|
||||||
description="Stability Generative Models",
|
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
|
||||||
long_description_content_type="text/markdown",
|
|
||||||
url="https://github.com/Stability-AI/generative-models",
|
|
||||||
)
|
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
from .data import StableDataModuleFromConfig
|
|
||||||
from .models import AutoencodingEngine, DiffusionEngine
|
from .models import AutoencodingEngine, DiffusionEngine
|
||||||
from .util import instantiate_from_config
|
from .util import get_configs_path, instantiate_from_config
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import torchvision
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torchvision import transforms
|
import torchvision
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
class CIFAR10DataDictWrapper(Dataset):
|
class CIFAR10DataDictWrapper(Dataset):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import torchvision
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torchvision import transforms
|
import torchvision
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
class MNISTDataDictWrapper(Dataset):
|
class MNISTDataDictWrapper(Dataset):
|
||||||
|
|||||||
363
sgm/inference/api.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
import pathlib
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
||||||
|
do_sample)
|
||||||
|
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
||||||
|
DPMPP2SAncestralSampler,
|
||||||
|
EulerAncestralSampler,
|
||||||
|
EulerEDMSampler,
|
||||||
|
HeunEDMSampler,
|
||||||
|
LinearMultistepSampler)
|
||||||
|
from sgm.util import load_model_from_config
|
||||||
|
|
||||||
|
|
||||||
|
class ModelArchitecture(str, Enum):
|
||||||
|
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||||
|
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||||
|
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||||
|
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(str, Enum):
|
||||||
|
EULER_EDM = "EulerEDMSampler"
|
||||||
|
HEUN_EDM = "HeunEDMSampler"
|
||||||
|
EULER_ANCESTRAL = "EulerAncestralSampler"
|
||||||
|
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
||||||
|
DPMPP2M = "DPMPP2MSampler"
|
||||||
|
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
||||||
|
|
||||||
|
|
||||||
|
class Discretization(str, Enum):
|
||||||
|
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
||||||
|
EDM = "EDMDiscretization"
|
||||||
|
|
||||||
|
|
||||||
|
class Guider(str, Enum):
|
||||||
|
VANILLA = "VanillaCFG"
|
||||||
|
IDENTITY = "IdentityGuider"
|
||||||
|
|
||||||
|
|
||||||
|
class Thresholder(str, Enum):
|
||||||
|
NONE = "None"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingParams:
|
||||||
|
width: int = 1024
|
||||||
|
height: int = 1024
|
||||||
|
steps: int = 50
|
||||||
|
sampler: Sampler = Sampler.DPMPP2M
|
||||||
|
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||||
|
guider: Guider = Guider.VANILLA
|
||||||
|
thresholder: Thresholder = Thresholder.NONE
|
||||||
|
scale: float = 6.0
|
||||||
|
aesthetic_score: float = 5.0
|
||||||
|
negative_aesthetic_score: float = 5.0
|
||||||
|
img2img_strength: float = 1.0
|
||||||
|
orig_width: int = 1024
|
||||||
|
orig_height: int = 1024
|
||||||
|
crop_coords_top: int = 0
|
||||||
|
crop_coords_left: int = 0
|
||||||
|
sigma_min: float = 0.0292
|
||||||
|
sigma_max: float = 14.6146
|
||||||
|
rho: float = 3.0
|
||||||
|
s_churn: float = 0.0
|
||||||
|
s_tmin: float = 0.0
|
||||||
|
s_tmax: float = 999.0
|
||||||
|
s_noise: float = 1.0
|
||||||
|
eta: float = 1.0
|
||||||
|
order: int = 4
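
The defaults above correspond to the legacy DDPM schedule driven by a DPM++ 2M sampler; switching to the Karras/EDM schedule only means overriding the related fields. A sketch (values chosen for illustration):

params = SamplingParams(
    sampler=Sampler.EULER_EDM,
    discretization=Discretization.EDM,
    sigma_min=0.002,
    sigma_max=80.0,
    rho=7.0,
    steps=40,
)
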
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingSpec:
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
channels: int
|
||||||
|
factor: int
|
||||||
|
is_legacy: bool
|
||||||
|
config: str
|
||||||
|
ckpt: str
|
||||||
|
is_guided: bool
|
||||||
|
|
||||||
|
|
||||||
|
model_specs = {
|
||||||
|
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=False,
|
||||||
|
config="sd_xl_base.yaml",
|
||||||
|
ckpt="sd_xl_base_0.9.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_xl_refiner.yaml",
|
||||||
|
ckpt="sd_xl_refiner_0.9.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=False,
|
||||||
|
config="sd_xl_base.yaml",
|
||||||
|
ckpt="sd_xl_base_1.0.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
||||||
|
height=1024,
|
||||||
|
width=1024,
|
||||||
|
channels=4,
|
||||||
|
factor=8,
|
||||||
|
is_legacy=True,
|
||||||
|
config="sd_xl_refiner.yaml",
|
||||||
|
ckpt="sd_xl_refiner_1.0.safetensors",
|
||||||
|
is_guided=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingPipeline:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: ModelArchitecture,
|
||||||
|
model_path="checkpoints",
|
||||||
|
config_path="configs/inference",
|
||||||
|
device="cuda",
|
||||||
|
use_fp16=True,
|
||||||
|
) -> None:
|
||||||
|
if model_id not in model_specs:
|
||||||
|
raise ValueError(f"Model {model_id} not supported")
|
||||||
|
self.model_id = model_id
|
||||||
|
self.specs = model_specs[self.model_id]
|
||||||
|
self.config = str(pathlib.Path(config_path, self.specs.config))
|
||||||
|
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
||||||
|
self.device = device
|
||||||
|
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||||
|
|
||||||
|
def _load_model(self, device="cuda", use_fp16=True):
|
||||||
|
config = OmegaConf.load(self.config)
|
||||||
|
model = load_model_from_config(config, self.ckpt)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model {self.model_id} could not be loaded")
|
||||||
|
model.to(device)
|
||||||
|
if use_fp16:
|
||||||
|
model.conditioner.half()
|
||||||
|
model.model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def text_to_image(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
value_dict = asdict(params)
|
||||||
|
value_dict["prompt"] = prompt
|
||||||
|
value_dict["negative_prompt"] = negative_prompt
|
||||||
|
value_dict["target_width"] = params.width
|
||||||
|
value_dict["target_height"] = params.height
|
||||||
|
return do_sample(
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
params.height,
|
||||||
|
params.width,
|
||||||
|
self.specs.channels,
|
||||||
|
self.specs.factor,
|
||||||
|
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def image_to_image(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
|
||||||
|
if params.img2img_strength < 1.0:
|
||||||
|
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||||
|
sampler.discretization,
|
||||||
|
strength=params.img2img_strength,
|
||||||
|
)
|
||||||
|
height, width = image.shape[2], image.shape[3]
|
||||||
|
value_dict = asdict(params)
|
||||||
|
value_dict["prompt"] = prompt
|
||||||
|
value_dict["negative_prompt"] = negative_prompt
|
||||||
|
value_dict["target_width"] = width
|
||||||
|
value_dict["target_height"] = height
|
||||||
|
return do_img2img(
|
||||||
|
image,
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def refiner(
|
||||||
|
self,
|
||||||
|
params: SamplingParams,
|
||||||
|
image,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
samples: int = 1,
|
||||||
|
return_latents: bool = False,
|
||||||
|
):
|
||||||
|
sampler = get_sampler_config(params)
|
||||||
|
value_dict = {
|
||||||
|
"orig_width": image.shape[3] * 8,
|
||||||
|
"orig_height": image.shape[2] * 8,
|
||||||
|
"target_width": image.shape[3] * 8,
|
||||||
|
"target_height": image.shape[2] * 8,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"crop_coords_top": 0,
|
||||||
|
"crop_coords_left": 0,
|
||||||
|
"aesthetic_score": 6.0,
|
||||||
|
"negative_aesthetic_score": 2.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
return do_img2img(
|
||||||
|
image,
|
||||||
|
self.model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
samples,
|
||||||
|
skip_encode=True,
|
||||||
|
return_latents=return_latents,
|
||||||
|
filter=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_guider_config(params: SamplingParams):
|
||||||
|
if params.guider == Guider.IDENTITY:
|
||||||
|
guider_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||||
|
}
|
||||||
|
elif params.guider == Guider.VANILLA:
|
||||||
|
scale = params.scale
|
||||||
|
|
||||||
|
thresholder = params.thresholder
|
||||||
|
|
||||||
|
if thresholder == Thresholder.NONE:
|
||||||
|
dyn_thresh_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
guider_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||||
|
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return guider_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_discretization_config(params: SamplingParams):
|
||||||
|
if params.discretization == Discretization.LEGACY_DDPM:
|
||||||
|
discretization_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||||
|
}
|
||||||
|
elif params.discretization == Discretization.EDM:
|
||||||
|
discretization_config = {
|
||||||
|
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||||
|
"params": {
|
||||||
|
"sigma_min": params.sigma_min,
|
||||||
|
"sigma_max": params.sigma_max,
|
||||||
|
"rho": params.rho,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown discretization {params.discretization}")
|
||||||
|
return discretization_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_sampler_config(params: SamplingParams):
|
||||||
|
discretization_config = get_discretization_config(params)
|
||||||
|
guider_config = get_guider_config(params)
|
||||||
|
sampler = None
|
||||||
|
if params.sampler == Sampler.EULER_EDM:
|
||||||
|
return EulerEDMSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
s_churn=params.s_churn,
|
||||||
|
s_tmin=params.s_tmin,
|
||||||
|
s_tmax=params.s_tmax,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.HEUN_EDM:
|
||||||
|
return HeunEDMSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
s_churn=params.s_churn,
|
||||||
|
s_tmin=params.s_tmin,
|
||||||
|
s_tmax=params.s_tmax,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.EULER_ANCESTRAL:
|
||||||
|
return EulerAncestralSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
eta=params.eta,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
||||||
|
return DPMPP2SAncestralSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
eta=params.eta,
|
||||||
|
s_noise=params.s_noise,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.DPMPP2M:
|
||||||
|
return DPMPP2MSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
||||||
|
return LinearMultistepSampler(
|
||||||
|
num_steps=params.steps,
|
||||||
|
discretization_config=discretization_config,
|
||||||
|
guider_config=guider_config,
|
||||||
|
order=params.order,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"unknown sampler {params.sampler}!")
|
||||||
305
sgm/inference/helpers.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from imwatermark import WatermarkEncoder
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
from PIL import Image
|
||||||
|
from torch import autocast
|
||||||
|
|
||||||
|
from sgm.util import append_dims
|
||||||
|
|
||||||
|
|
||||||
|
class WatermarkEmbedder:
|
||||||
|
def __init__(self, watermark):
|
||||||
|
self.watermark = watermark
|
||||||
|
self.num_bits = len(WATERMARK_BITS)
|
||||||
|
self.encoder = WatermarkEncoder()
|
||||||
|
self.encoder.set_watermark("bits", self.watermark)
|
||||||
|
|
||||||
|
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Adds a predefined watermark to the input image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: ([N,] B, RGB, H, W) in range [0, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
same as input but watermarked
|
||||||
|
"""
|
||||||
|
squeeze = len(image.shape) == 4
|
||||||
|
if squeeze:
|
||||||
|
image = image[None, ...]
|
||||||
|
n = image.shape[0]
|
||||||
|
image_np = rearrange(
|
||||||
|
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||||
|
).numpy()[:, :, :, ::-1]
|
||||||
|
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||||
|
# the watermarking library expects input in cv2 BGR format
|
||||||
|
for k in range(image_np.shape[0]):
|
||||||
|
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||||
|
image = torch.from_numpy(
|
||||||
|
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
||||||
|
).to(image.device)
|
||||||
|
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||||
|
if squeeze:
|
||||||
|
image = image[0]
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# A fixed 48-bit message that was chosen at random
|
||||||
|
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||||
|
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||||
|
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||||
|
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||||
|
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||||
|
return list({x.input_key for x in conditioner.embedders})
|
||||||
|
|
||||||
|
|
||||||
|
def perform_save_locally(save_path, samples):
|
||||||
|
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||||
|
base_count = len(os.listdir(os.path.join(save_path)))
|
||||||
|
samples = embed_watermark(samples)
|
||||||
|
for sample in samples:
|
||||||
|
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||||
|
Image.fromarray(sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(save_path, f"{base_count:09}.png")
|
||||||
|
)
|
||||||
|
base_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
class Img2ImgDiscretizationWrapper:
|
||||||
|
"""
|
||||||
|
wraps a discretizer and prunes the sigmas according to `strength`
|
||||||
|
params:
|
||||||
|
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discretization, strength: float = 1.0):
|
||||||
|
self.discretization = discretization
|
||||||
|
self.strength = strength
|
||||||
|
assert 0.0 <= self.strength <= 1.0
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
# sigmas start large and then decrease
|
||||||
|
sigmas = self.discretization(*args, **kwargs)
|
||||||
|
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||||
|
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||||
|
sigmas = torch.flip(sigmas, (0,))
|
||||||
|
print(f"sigmas after pruning: ", sigmas)
|
||||||
|
return sigmas
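
For example, wrapping a 50-step schedule with strength 0.5 keeps only the smallest ~25 sigmas, so denoising starts from a moderately noised version of the input rather than from pure noise. A sketch (base_discretization stands in for any sgm discretization instance):

wrapped = Img2ImgDiscretizationWrapper(base_discretization, strength=0.5)
sigmas = wrapped(50)  # roughly the lower half of the original schedule
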
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
H,
|
||||||
|
W,
|
||||||
|
C,
|
||||||
|
F,
|
||||||
|
force_uc_zero_embeddings: Optional[List] = None,
|
||||||
|
batch2model_input: Optional[List] = None,
|
||||||
|
return_latents=False,
|
||||||
|
filter=None,
|
||||||
|
device="cuda",
|
||||||
|
):
|
||||||
|
if force_uc_zero_embeddings is None:
|
||||||
|
force_uc_zero_embeddings = []
|
||||||
|
if batch2model_input is None:
|
||||||
|
batch2model_input = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with autocast(device) as precision_scope:
|
||||||
|
with model.ema_scope():
|
||||||
|
num_samples = [num_samples]
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
)
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
print(key, batch[key].shape)
|
||||||
|
elif isinstance(batch[key], list):
|
||||||
|
print(key, [len(l) for l in batch[key]])
|
||||||
|
else:
|
||||||
|
print(key, batch[key])
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in c:
|
||||||
|
if not k == "crossattn":
|
||||||
|
c[k], uc[k] = map(
|
||||||
|
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
||||||
|
)
|
||||||
|
|
||||||
|
additional_model_inputs = {}
|
||||||
|
for k in batch2model_input:
|
||||||
|
additional_model_inputs[k] = batch[k]
|
||||||
|
|
||||||
|
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||||
|
randn = torch.randn(shape).to(device)
|
||||||
|
|
||||||
|
def denoiser(input, sigma, c):
|
||||||
|
return model.denoiser(
|
||||||
|
model.model, input, sigma, c, **additional_model_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||||
|
# Hardcoded demo setups; might undergo some changes in the future
|
||||||
|
|
||||||
|
batch = {}
|
||||||
|
batch_uc = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key == "txt":
|
||||||
|
batch["txt"] = (
|
||||||
|
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
batch_uc["txt"] = (
|
||||||
|
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||||
|
.reshape(N)
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
elif key == "original_size_as_tuple":
|
||||||
|
batch["original_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "crop_coords_top_left":
|
||||||
|
batch["crop_coords_top_left"] = (
|
||||||
|
torch.tensor(
|
||||||
|
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||||
|
)
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
elif key == "aesthetic_score":
|
||||||
|
batch["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||||
|
)
|
||||||
|
batch_uc["aesthetic_score"] = (
|
||||||
|
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif key == "target_size_as_tuple":
|
||||||
|
batch["target_size_as_tuple"] = (
|
||||||
|
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||||
|
.to(device)
|
||||||
|
.repeat(*N, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch[key] = value_dict[key]
|
||||||
|
|
||||||
|
for key in batch.keys():
|
||||||
|
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||||
|
batch_uc[key] = torch.clone(batch[key])
|
||||||
|
return batch, batch_uc
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input image of size ({w}, {h})")
|
||||||
|
width, height = map(
|
||||||
|
lambda x: x - x % 64, (w, h)
|
||||||
|
) # resize to integer multiple of 64
|
||||||
|
image = image.resize((width, height))
|
||||||
|
image_array = np.array(image.convert("RGB"))
|
||||||
|
image_array = image_array[None].transpose(0, 3, 1, 2)
|
||||||
|
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
||||||
|
return image_tensor.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def do_img2img(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
sampler,
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
|
force_uc_zero_embeddings=[],
|
||||||
|
additional_kwargs={},
|
||||||
|
offset_noise_level: float = 0.0,
|
||||||
|
return_latents=False,
|
||||||
|
skip_encode=False,
|
||||||
|
filter=None,
|
||||||
|
device="cuda",
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
with autocast(device) as precision_scope:
|
||||||
|
with model.ema_scope():
|
||||||
|
batch, batch_uc = get_batch(
|
||||||
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
[num_samples],
|
||||||
|
)
|
||||||
|
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||||
|
batch,
|
||||||
|
batch_uc=batch_uc,
|
||||||
|
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in c:
|
||||||
|
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
||||||
|
|
||||||
|
for k in additional_kwargs:
|
||||||
|
c[k] = uc[k] = additional_kwargs[k]
|
||||||
|
if skip_encode:
|
||||||
|
z = img
|
||||||
|
else:
|
||||||
|
z = model.encode_first_stage(img)
|
||||||
|
noise = torch.randn_like(z)
|
||||||
|
sigmas = sampler.discretization(sampler.num_steps)
|
||||||
|
sigma = sigmas[0].to(z.device)
|
||||||
|
|
||||||
|
if offset_noise_level > 0.0:
|
||||||
|
noise = noise + offset_noise_level * append_dims(
|
||||||
|
torch.randn(z.shape[0], device=z.device), z.ndim
|
||||||
|
)
|
||||||
|
noised_z = z + noise * append_dims(sigma, z.ndim)
|
||||||
|
noised_z = noised_z / torch.sqrt(
|
||||||
|
1.0 + sigmas[0] ** 2.0
|
||||||
|
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
||||||
|
|
||||||
|
def denoiser(x, sigma, c):
|
||||||
|
return model.denoiser(model.model, x, sigma, c)
|
||||||
|
|
||||||
|
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
||||||
|
samples_x = model.decode_first_stage(samples_z)
|
||||||
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if filter is not None:
|
||||||
|
samples = filter(samples)
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return samples, samples_z
|
||||||
|
return samples
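
Tying these helpers together with the API above, a typical img2img call looks roughly like this ("input.png" is a placeholder, and the checkpoint/config defaults of SamplingPipeline are assumed):

from PIL import Image

from sgm.inference.api import (ModelArchitecture, SamplingParams,
                               SamplingPipeline)
from sgm.inference.helpers import get_input_image_tensor

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
init = get_input_image_tensor(Image.open("input.png"))
out = pipeline.image_to_image(
    SamplingParams(img2img_strength=0.75, steps=40),
    init,
    prompt="a watercolor painting of the same scene",
)
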
|
||||||
@@ -1,18 +1,22 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
import re
|
import re
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import ListConfig
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from safetensors.torch import load_file as load_safetensors
|
|
||||||
|
|
||||||
from ..modules.diffusionmodules.model import Decoder, Encoder
|
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
||||||
from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
|
||||||
from ..modules.ema import LitEma
|
from ..modules.ema import LitEma
|
||||||
from ..util import default, get_obj_from_str, instantiate_from_config
|
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
||||||
|
instantiate_from_config)
|
||||||
|
|
||||||
|
logpy = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AbstractAutoencoder(pl.LightningModule):
|
class AbstractAutoencoder(pl.LightningModule):
|
||||||
@@ -27,10 +31,9 @@ class AbstractAutoencoder(pl.LightningModule):
|
|||||||
ema_decay: Union[None, float] = None,
|
ema_decay: Union[None, float] = None,
|
||||||
monitor: Union[None, str] = None,
|
monitor: Union[None, str] = None,
|
||||||
input_key: str = "jpg",
|
input_key: str = "jpg",
|
||||||
ckpt_path: Union[None, str] = None,
|
|
||||||
ignore_keys: Union[Tuple, list, ListConfig] = (),
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input_key = input_key
|
self.input_key = input_key
|
||||||
self.use_ema = ema_decay is not None
|
self.use_ema = ema_decay is not None
|
||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
@@ -38,38 +41,21 @@ class AbstractAutoencoder(pl.LightningModule):
|
|||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema = LitEma(self, decay=ema_decay)
|
self.model_ema = LitEma(self, decay=ema_decay)
|
||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization = False
|
||||||
|
|
||||||
def init_from_ckpt(
|
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||||
self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
|
if ckpt is None:
|
||||||
) -> None:
|
return
|
||||||
if path.endswith("ckpt"):
|
if isinstance(ckpt, str):
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
ckpt = {
|
||||||
elif path.endswith("safetensors"):
|
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
||||||
sd = load_safetensors(path)
|
"params": {"ckpt_path": ckpt},
|
||||||
else:
|
}
|
||||||
raise NotImplementedError
|
engine = instantiate_from_config(ckpt)
|
||||||
|
engine(self)
|
||||||
keys = list(sd.keys())
|
|
||||||
for k in keys:
|
|
||||||
for ik in ignore_keys:
|
|
||||||
if re.match(ik, k):
|
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
|
||||||
del sd[k]
|
|
||||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
|
||||||
print(
|
|
||||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
|
||||||
)
|
|
||||||
if len(missing) > 0:
|
|
||||||
print(f"Missing Keys: {missing}")
|
|
||||||
if len(unexpected) > 0:
|
|
||||||
print(f"Unexpected Keys: {unexpected}")
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_input(self, batch) -> Any:
|
def get_input(self, batch) -> Any:
|
||||||
@@ -86,14 +72,14 @@ class AbstractAutoencoder(pl.LightningModule):
|
|||||||
self.model_ema.store(self.parameters())
|
self.model_ema.store(self.parameters())
|
||||||
self.model_ema.copy_to(self)
|
self.model_ema.copy_to(self)
|
||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Switched to EMA weights")
|
logpy.info(f"{context}: Switched to EMA weights")
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema.restore(self.parameters())
|
self.model_ema.restore(self.parameters())
|
||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
logpy.info(f"{context}: Restored training weights")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||||
@@ -104,7 +90,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
|||||||
raise NotImplementedError("decode()-method of abstract base class called")
|
raise NotImplementedError("decode()-method of abstract base class called")
|
||||||
|
|
||||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||||
print(f"loading >>> {cfg['target']} <<< optimizer from config")
|
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||||
return get_obj_from_str(cfg["target"])(
|
return get_obj_from_str(cfg["target"])(
|
||||||
params, lr=lr, **cfg.get("params", dict())
|
params, lr=lr, **cfg.get("params", dict())
|
||||||
)
|
)
|
||||||
@@ -129,196 +115,435 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
regularizer_config: Dict,
|
regularizer_config: Dict,
|
||||||
optimizer_config: Union[Dict, None] = None,
|
optimizer_config: Union[Dict, None] = None,
|
||||||
lr_g_factor: float = 1.0,
|
lr_g_factor: float = 1.0,
|
||||||
|
trainable_ae_params: Optional[List[List[str]]] = None,
|
||||||
|
ae_optimizer_args: Optional[List[dict]] = None,
|
||||||
|
trainable_disc_params: Optional[List[List[str]]] = None,
|
||||||
|
disc_optimizer_args: Optional[List[dict]] = None,
|
||||||
|
disc_start_iter: int = 0,
|
||||||
|
diff_boost_factor: float = 3.0,
|
||||||
|
ckpt_engine: Union[None, str, dict] = None,
|
||||||
|
ckpt_path: Optional[str] = None,
|
||||||
|
additional_decode_keys: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
# todo: add options to freeze encoder/decoder
|
self.automatic_optimization = False # pytorch lightning
|
||||||
self.encoder = instantiate_from_config(encoder_config)
|
|
||||||
self.decoder = instantiate_from_config(decoder_config)
|
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||||
self.loss = instantiate_from_config(loss_config)
|
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||||
self.regularization = instantiate_from_config(regularizer_config)
|
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
||||||
|
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||||
|
regularizer_config
|
||||||
|
)
|
||||||
self.optimizer_config = default(
|
self.optimizer_config = default(
|
||||||
optimizer_config, {"target": "torch.optim.Adam"}
|
optimizer_config, {"target": "torch.optim.Adam"}
|
||||||
)
|
)
|
||||||
|
self.diff_boost_factor = diff_boost_factor
|
||||||
|
self.disc_start_iter = disc_start_iter
|
||||||
self.lr_g_factor = lr_g_factor
|
self.lr_g_factor = lr_g_factor
|
||||||
|
self.trainable_ae_params = trainable_ae_params
|
||||||
|
if self.trainable_ae_params is not None:
|
||||||
|
self.ae_optimizer_args = default(
|
||||||
|
ae_optimizer_args,
|
||||||
|
[{} for _ in range(len(self.trainable_ae_params))],
|
||||||
|
)
|
||||||
|
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
||||||
|
else:
|
||||||
|
self.ae_optimizer_args = [{}]  # makes type consistent
|
||||||
|
|
||||||
|
self.trainable_disc_params = trainable_disc_params
|
||||||
|
if self.trainable_disc_params is not None:
|
||||||
|
self.disc_optimizer_args = default(
|
||||||
|
disc_optimizer_args,
|
||||||
|
[{} for _ in range(len(self.trainable_disc_params))],
|
||||||
|
)
|
||||||
|
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
||||||
|
else:
|
||||||
|
self.disc_optimizer_args = [{}] # makes type consitent
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
||||||
|
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
||||||
|
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
||||||
|
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
||||||
|
|
||||||
def get_input(self, batch: Dict) -> torch.Tensor:
|
def get_input(self, batch: Dict) -> torch.Tensor:
|
||||||
# assuming unified data format, dataloader returns a dict.
|
# assuming unified data format, dataloader returns a dict.
|
||||||
# image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
|
# image tensors should be scaled to -1 ... 1 and in channels-first
|
||||||
|
# format (e.g., bchw instead of bhwc)
|
||||||
return batch[self.input_key]
|
return batch[self.input_key]
|
||||||
|
|
||||||
def get_autoencoder_params(self) -> list:
|
def get_autoencoder_params(self) -> list:
|
||||||
params = (
|
params = []
|
||||||
list(self.encoder.parameters())
|
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
||||||
+ list(self.decoder.parameters())
|
params += list(self.loss.get_trainable_autoencoder_parameters())
|
||||||
+ list(self.regularization.get_trainable_parameters())
|
if hasattr(self.regularization, "get_trainable_parameters"):
|
||||||
+ list(self.loss.get_trainable_autoencoder_parameters())
|
params += list(self.regularization.get_trainable_parameters())
|
||||||
)
|
params = params + list(self.encoder.parameters())
|
||||||
|
params = params + list(self.decoder.parameters())
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_discriminator_params(self) -> list:
|
def get_discriminator_params(self) -> list:
|
||||||
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
if hasattr(self.loss, "get_trainable_parameters"):
|
||||||
|
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
||||||
|
else:
|
||||||
|
params = []
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_last_layer(self):
|
def get_last_layer(self):
|
||||||
return self.decoder.get_last_layer()
|
return self.decoder.get_last_layer()
|
||||||
|
|
||||||
    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x)
        if unregularized:
            return z, dict()
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x
    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log

    def inner_training_step(
        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
    ) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {
            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
        }
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": optimizer_idx,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "train",
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()

        if optimizer_idx == 0:
            # autoencode
            out_loss = self.loss(x, xrec, **extra_info)
            if isinstance(out_loss, tuple):
                aeloss, log_dict_ae = out_loss
            else:
                # simple loss function
                aeloss = out_loss
                log_dict_ae = {"train/loss/rec": aeloss.detach()}

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=False,
            )
            self.log(
                "loss",
                aeloss.mean().detach(),
                prog_bar=True,
                logger=False,
                on_epoch=False,
                on_step=True,
            )
            return aeloss
        elif optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            # -> discriminator always needs to return a tuple
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss
        else:
            raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
    def training_step(self, batch: dict, batch_idx: int):
        opts = self.optimizers()
        if not isinstance(opts, list):
            # Non-adversarial case
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)
        if self.global_step < self.disc_start_iter:
            optimizer_idx = 0
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():
            loss = self.inner_training_step(
                batch, batch_idx, optimizer_idx=optimizer_idx
            )
            self.manual_backward(loss)
        opt.step()
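The new training_step uses Lightning's manual optimization (automatic_optimization must be disabled for manual_backward and opt.step to be legal). A minimal, self-contained sketch of the same pattern on a hypothetical toy module, not the engine above:

import pytorch_lightning as pl
import torch
import torch.nn as nn

class TwoOptimizerToy(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # required for manual_backward()/opt.step()
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opts = self.optimizers()
        opt = opts[batch_idx % len(opts)]  # alternate the two optimizers by batch index
        opt.zero_grad()
        with opt.toggle_model():  # only this optimizer's parameters receive gradients
            if batch_idx % 2 == 0:
                loss = self.gen(batch).pow(2).mean()
            else:
                loss = self.disc(batch).pow(2).mean()
            self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return [torch.optim.Adam(self.gen.parameters(), lr=1e-4),
                torch.optim.Adam(self.disc.parameters(), lr=1e-4)]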
    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": 0,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "val" + postfix,
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = dict()
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
        full_log_dict = log_dict_ae

        if "optimizer_idx" in extra_info:
            extra_info["optimizer_idx"] = 1
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            full_log_dict.update(log_dict_disc)
        self.log(
            f"val{postfix}/loss/rec",
            log_dict_ae[f"val{postfix}/loss/rec"],
            sync_dist=True,
        )
        self.log_dict(full_log_dict, sync_dist=True)
        return full_log_dict
    def get_param_groups(
        self, parameter_names: List[List[str]], optimizer_args: List[dict]
    ) -> Tuple[List[Dict[str, Any]], int]:
        groups = []
        num_params = 0
        for names, args in zip(parameter_names, optimizer_args):
            params = []
            for pattern_ in names:
                pattern_params = []
                pattern = re.compile(pattern_)
                for p_name, param in self.named_parameters():
                    if re.match(pattern, p_name):
                        pattern_params.append(param)
                        num_params += param.numel()
                if len(pattern_params) == 0:
                    logpy.warn(f"Did not find parameters for pattern {pattern_}")
                params.extend(pattern_params)
            groups.append({"params": params, **args})
        return groups, num_params
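Outside the Lightning class, the same regex-driven parameter selection can be expressed as a small helper; the function name below is illustrative, only the matching logic mirrors get_param_groups:

import re
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn

def param_groups_from_patterns(
    module: nn.Module, patterns: List[str], lr: float
) -> Tuple[List[Dict[str, Any]], int]:
    groups: List[Dict[str, Any]] = []
    num_params = 0
    for pattern in patterns:
        compiled = re.compile(pattern)
        params = [p for name, p in module.named_parameters() if compiled.match(name)]
        num_params += sum(p.numel() for p in params)
        groups.append({"params": params, "lr": lr})
    return groups, num_params

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
groups, n = param_groups_from_patterns(model, [r"0\..*", r"1\..*"], lr=1e-4)
opt = torch.optim.Adam(groups)  # one param group per pattern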
    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(
                self.trainable_ae_params, self.ae_optimizer_args
            )
            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(
                self.trainable_disc_params, self.disc_optimizer_args
            )
            logpy.info(
                f"Number of trainable discriminator parameters: {num_disc_params:,}"
            )
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(
                disc_params, self.learning_rate, self.optimizer_config
            )
            opts.append(opt_disc)

        return opts
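instantiate_optimizer_from_config itself is defined elsewhere in this file; below is only a hedged sketch of what such config-driven instantiation looks like, following the {"target": ..., "params": ...} convention used throughout this diff (the helper name is illustrative):

import importlib

import torch
import torch.nn as nn

def optimizer_from_config(params, lr, cfg):
    # Resolve "package.module.Class" from the config and construct it with params/lr.
    module_name, cls_name = cfg["target"].rsplit(".", 1)
    opt_cls = getattr(importlib.import_module(module_name), cls_name)
    return opt_cls(params, lr=lr, **cfg.get("params", {}))

opt = optimizer_from_config(
    nn.Linear(4, 4).parameters(),
    lr=1e-4,
    cfg={"target": "torch.optim.AdamW", "params": {"weight_decay": 0.01}},
)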
    @torch.no_grad()
    def log_images(
        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
    ) -> dict:
        log = dict()
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update(
            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        )

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log["inputs"] = x
        log["reconstructions"] = xrec
        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
        diff.clamp_(0, 1.0)
        log["diff"] = 2.0 * diff - 1.0
        # diff_boost shows location of small errors, by boosting their
        # brightness.
        log["diff_boost"] = (
            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
        )
        if hasattr(self.loss, "log_images"):
            log.update(self.loss.log_images(x, xrec))
        with self.ema_scope():
            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
            log["reconstructions_ema"] = xrec_ema
            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
            diff_ema.clamp_(0, 1.0)
            log["diff_ema"] = 2.0 * diff_ema - 1.0
            log["diff_boost_ema"] = (
                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
            )
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
            log_str = "reconstructions-" + "-".join(
                [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
            )
            log[log_str] = xrec_add
        return log
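The diff and diff_boost images are plain tensor arithmetic; a standalone sketch (the tensors and the 3.0 factor are illustrative, the engine reads self.diff_boost_factor from its own config):

import torch

x = torch.rand(1, 3, 64, 64) * 2 - 1                  # stand-in for the input, in [-1, 1]
xrec = (x + 0.05 * torch.randn_like(x)).clamp(-1, 1)  # stand-in for the reconstruction

diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)  # per-pixel error mapped to [0, 1]
diff_boost_factor = 3.0                                   # illustrative value
diff_boost = 2.0 * torch.clamp(diff_boost_factor * diff, 0.0, 1.0) - 1.0  # back to [-1, 1]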
class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ckpt_engine = kwargs.pop("ckpt_engine", None)
        super().__init__(
            encoder_config={
                "target": "sgm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "sgm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec
class AutoencoderKL(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )


class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
    def __init__(
        self,
        embed_dim: int,
        n_embed: int,
        sane_index_shape: bool = False,
        **kwargs,
    ):
        if "lossconfig" in kwargs:
            logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
                ),
                "params": {
                    "n_e": n_embed,
                    "e_dim": embed_dim,
                    "sane_index_shape": sane_index_shape,
                },
            },
            **kwargs,
        )


@@ -333,3 +558,58 @@ class IdentityFirstStage(AbstractAutoencoder):
    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x


class AEIntegerWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
        regularization_key: str = "regularization",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        assert hasattr(model, "encode") and hasattr(
            model, "decode"
        ), "Need AE interface"
        self.regularization = get_nested_attribute(model, regularization_key)
        self.shape = shape
        self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})

    def encode(self, x) -> torch.Tensor:
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        _, log = self.model.encode(x, **self.encoder_kwargs)
        assert isinstance(log, dict)
        inds = log["min_encoding_indices"]
        return rearrange(inds, "b ... -> b (...)")

    def decode(
        self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
    ) -> torch.Tensor:
        # expect inds shape (b, s) with s = h*w
        shape = default(shape, self.shape)  # Optional[(h, w)]
        if shape is not None:
            assert len(shape) == 2, f"Unhandeled shape {shape}"
            inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
        h = self.regularization.get_codebook_entry(inds)  # (b, h, w, c)
        h = rearrange(h, "b h w c -> b c h w")
        return self.model.decode(h)


class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "sgm.modules.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                ),
                "params": {"sample": False},
            },
            **kwargs,
        )
@@ -1,5 +1,6 @@
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
@@ -8,15 +9,11 @@ from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import (default, disabled_train, get_obj_from_str,
                    instantiate_from_config, log_txt_as_img)


class DiffusionEngine(pl.LightningModule):
@@ -40,6 +37,7 @@ class DiffusionEngine(pl.LightningModule):
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()
        self.log_keys = log_keys
@@ -82,6 +80,8 @@ class DiffusionEngine(pl.LightningModule):
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def init_from_ckpt(
        self,
        path: str,
@@ -117,14 +117,35 @@ class DiffusionEngine(pl.LightningModule):
    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples : (n + 1) * n_samples]
                )
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z
@@ -258,14 +279,10 @@ class DiffusionEngine(pl.LightningModule):
                    xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                else:
                    raise NotImplementedError()
            elif isinstance(x, (List, ListConfig)):
                if isinstance(x[0], str):
                    # strings
                    xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                else:
                    raise NotImplementedError()
            else:
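Both en_and_decode_n_samples_a_time here and max_batch_size in the legacy autoencoder implement the same micro-batching idea; a minimal standalone sketch (fn stands in for any per-batch callable such as a decoder):

import math

import torch

@torch.no_grad()
def run_in_chunks(fn, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Apply fn to slices of the batch dimension and concatenate the results.
    n_rounds = math.ceil(x.shape[0] / chunk_size)
    outs = [fn(x[n * chunk_size : (n + 1) * chunk_size]) for n in range(n_rounds)]
    return torch.cat(outs, dim=0)

z = torch.randn(10, 4, 8, 8)
dec = run_in_chunks(lambda zz: zz * 2.0, z, chunk_size=4)  # three rounds: 4 + 4 + 2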
@@ -1,3 +1,4 @@
import logging
import math
from inspect import isfunction
from typing import Any, Optional
@@ -7,6 +8,9 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint

logpy = logging.getLogger(__name__)

if version.parse(torch.__version__) >= version.parse("2.0.0"):
    SDP_IS_AVAILABLE = True
@@ -36,9 +40,10 @@ else:
    SDP_IS_AVAILABLE = False
    sdp_kernel = nullcontext
    BACKEND_MAP = {}
    logpy.warn(
        f"No SDP backend available, likely because you are running in pytorch "
        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
        f"You might want to consider upgrading."
    )

try:
@@ -48,9 +53,9 @@ try:
    XFORMERS_IS_AVAILABLE = True
except:
    XFORMERS_IS_AVAILABLE = False
    logpy.warn("no module 'xformers'. Processing without...")

# from .diffusionmodules.util import mixed_checkpoint as checkpoint


def exists(val):
@@ -146,6 +151,62 @@ class LinearAttention(nn.Module):
        return self.to_out(out)


class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "math")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        attn_mode: str = "xformers",
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, C = x.shape

        qkv = self.qkv(x)
        if self.attn_mode == "torch":
            qkv = rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = rearrange(x, "B H L D -> B L (H D)")
        elif self.attn_mode == "xformers":
            qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
        elif self.attn_mode == "math":
            qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplemented

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
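A usage sketch for the new SelfAttention block; the import path is inferred from the sgm.modules.* paths elsewhere in this diff (not stated in this hunk), and the "math" mode is chosen because it needs neither xformers nor PyTorch 2.0:

import torch

from sgm.modules.attention import SelfAttention  # path assumed

attn = SelfAttention(dim=64, num_heads=8, attn_mode="math")
tokens = torch.randn(2, 16, 64)  # (batch, sequence length, dim)
out = attn(tokens)               # same shape: (2, 16, 64)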
@@ -289,9 +350,10 @@ class MemoryEfficientCrossAttention(nn.Module):
        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
    ):
        super().__init__()
        logpy.debug(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
            f"context_dim is {context_dim} and using {heads} heads with a "
            f"dimension of {dim_head}."
        )
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -352,9 +414,29 @@ class MemoryEfficientCrossAttention(nn.Module):
        )

        # actually compute the attention, what we cannot get enough of
        if version.parse(xformers.__version__) >= version.parse("0.0.21"):
            # NOTE: workaround for
            # https://github.com/facebookresearch/xformers/issues/845
            max_bs = 32768
            N = q.shape[0]
            n_batches = math.ceil(N / max_bs)
            out = list()
            for i_batch in range(n_batches):
                batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
                out.append(
                    xformers.ops.memory_efficient_attention(
                        q[batch],
                        k[batch],
                        v[batch],
                        attn_bias=None,
                        op=self.attention_op,
                    )
                )
            out = torch.cat(out, 0)
        else:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op
            )

        # TODO: Use this directly in the attention operation, as a bias
        if exists(mask):
@@ -393,21 +475,24 @@ class BasicTransformerBlock(nn.Module):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
            logpy.warn(
                f"Attention mode '{attn_mode}' is not available. Falling "
                f"back to native attention. This is not a problem in "
                f"Pytorch >= 2.0. FYI, you are running with PyTorch "
                f"version {torch.__version__}."
            )
            attn_mode = "softmax"
        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
            logpy.warn(
                "We do not support vanilla attention anymore, as it is too "
                "expensive. Sorry."
            )
            if not XFORMERS_IS_AVAILABLE:
                assert (
                    False
                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                logpy.info("Falling back to xformers efficient attention.")
                attn_mode = "softmax-xformers"
        attn_cls = self.ATTENTION_MODES[attn_mode]
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
@@ -437,7 +522,7 @@ class BasicTransformerBlock(nn.Module):
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        if self.checkpoint:
            logpy.debug(f"{self.__class__.__name__} is using checkpointing")

    def forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -456,9 +541,12 @@ class BasicTransformerBlock(nn.Module):
        )

        # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
        if self.checkpoint:
            # inputs = {"x": x, "context": context}
            return checkpoint(self._forward, x, context)
            # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
        else:
            return self._forward(**kwargs)

    def _forward(
        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -518,9 +606,9 @@ class BasicTransformerSingleLayerBlock(nn.Module):
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # inputs = {"x": x, "context": context}
        # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
        return checkpoint(self._forward, x, context)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x), context=context) + x
@@ -554,18 +642,20 @@ class SpatialTransformer(nn.Module):
        sdp_backend=None,
    ):
        super().__init__()
        logpy.debug(
            f"constructing {self.__class__.__name__} of depth {depth} w/ "
            f"{in_channels} channels and {n_heads} heads."
        )

        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        if exists(context_dim) and isinstance(context_dim, list):
            if depth != len(context_dim):
                logpy.warn(
                    f"{self.__class__.__name__}: Found context dims "
                    f"{context_dim} of depth {len(context_dim)}, which does not "
                    f"match the specified 'depth' of {depth}. Setting context_dim "
                    f"to {depth * [context_dim[0]]} now."
                )
                # depth does not match context dims.
                assert all(
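The blocks above now call torch.utils.checkpoint.checkpoint directly instead of the local diffusionmodules.util helper; a minimal sketch of that API on a toy module (use_reentrant=False is the variant PyTorch currently recommends, not what the diff itself passes):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32))
x = torch.randn(4, 32, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)  # activations are recomputed in backward
y.sum().backward()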
@@ -633,315 +723,37 @@ class SpatialTransformer(nn.Module):
        return x + x_in


class SimpleTransformer(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        context_dim: Optional[int] = None,
        dropout: float = 0.0,
        checkpoint: bool = True,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                BasicTransformerBlock(
                    dim,
                    heads,
                    dim_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    attn_mode="softmax-xformers",
                    checkpoint=checkpoint,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, context)
        return x
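A usage sketch for SimpleTransformer; the import path is assumed as above, and because the block hard-codes attn_mode="softmax-xformers" it needs xformers and a CUDA device in practice:

import torch

from sgm.modules.attention import SimpleTransformer  # path assumed

model = SimpleTransformer(dim=128, depth=2, heads=4, dim_head=32).cuda()
x = torch.randn(2, 64, 128, device="cuda")
y = model(x)  # (2, 64, 128)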
@@ -1,246 +1,7 @@
|
|||||||
from typing import Any, Union
|
__all__ = [
|
||||||
|
"GeneralLPIPSWithDiscriminator",
|
||||||
|
"LatentLPIPS",
|
||||||
|
]
|
||||||
|
|
||||||
import torch
|
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
||||||
import torch.nn as nn
|
from .lpips import LatentLPIPS
|
||||||
from einops import rearrange
|
|
||||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
|
||||||
from taming.modules.losses.lpips import LPIPS
|
|
||||||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
|
||||||
|
|
||||||
from ....util import default, instantiate_from_config
|
|
||||||
|
|
||||||
|
|
||||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
|
||||||
if global_step < threshold:
|
|
||||||
weight = value
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
class LatentLPIPS(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
decoder_config,
|
|
||||||
perceptual_weight=1.0,
|
|
||||||
latent_weight=1.0,
|
|
||||||
scale_input_to_tgt_size=False,
|
|
||||||
scale_tgt_to_input_size=False,
|
|
||||||
perceptual_weight_on_inputs=0.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
|
||||||
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
|
||||||
self.init_decoder(decoder_config)
|
|
||||||
self.perceptual_loss = LPIPS().eval()
|
|
||||||
self.perceptual_weight = perceptual_weight
|
|
||||||
self.latent_weight = latent_weight
|
|
||||||
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
|
||||||
|
|
||||||
def init_decoder(self, config):
|
|
||||||
self.decoder = instantiate_from_config(config)
|
|
||||||
if hasattr(self.decoder, "encoder"):
|
|
||||||
del self.decoder.encoder
|
|
||||||
|
|
||||||
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
|
||||||
log = dict()
|
|
||||||
loss = (latent_inputs - latent_predictions) ** 2
|
|
||||||
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
|
||||||
image_reconstructions = None
|
|
||||||
if self.perceptual_weight > 0.0:
|
|
||||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
|
||||||
image_targets = self.decoder.decode(latent_inputs)
|
|
||||||
perceptual_loss = self.perceptual_loss(
|
|
||||||
image_targets.contiguous(), image_reconstructions.contiguous()
|
|
||||||
)
|
|
||||||
loss = (
|
|
||||||
self.latent_weight * loss.mean()
|
|
||||||
+ self.perceptual_weight * perceptual_loss.mean()
|
|
||||||
)
|
|
||||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
|
||||||
|
|
||||||
if self.perceptual_weight_on_inputs > 0.0:
|
|
||||||
image_reconstructions = default(
|
|
||||||
image_reconstructions, self.decoder.decode(latent_predictions)
|
|
||||||
)
|
|
||||||
if self.scale_input_to_tgt_size:
|
|
||||||
image_inputs = torch.nn.functional.interpolate(
|
|
||||||
image_inputs,
|
|
||||||
image_reconstructions.shape[2:],
|
|
||||||
mode="bicubic",
|
|
||||||
antialias=True,
|
|
||||||
)
|
|
||||||
elif self.scale_tgt_to_input_size:
|
|
||||||
image_reconstructions = torch.nn.functional.interpolate(
|
|
||||||
image_reconstructions,
|
|
||||||
image_inputs.shape[2:],
|
|
||||||
mode="bicubic",
|
|
||||||
antialias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
perceptual_loss2 = self.perceptual_loss(
|
|
||||||
image_inputs.contiguous(), image_reconstructions.contiguous()
|
|
||||||
)
|
|
||||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
|
||||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
|
||||||
return loss, log
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralLPIPSWithDiscriminator(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
disc_start: int,
|
|
||||||
logvar_init: float = 0.0,
|
|
||||||
pixelloss_weight=1.0,
|
|
||||||
disc_num_layers: int = 3,
|
|
||||||
disc_in_channels: int = 3,
|
|
||||||
disc_factor: float = 1.0,
|
|
||||||
disc_weight: float = 1.0,
|
|
||||||
perceptual_weight: float = 1.0,
|
|
||||||
disc_loss: str = "hinge",
|
|
||||||
scale_input_to_tgt_size: bool = False,
|
|
||||||
dims: int = 2,
|
|
||||||
learn_logvar: bool = False,
|
|
||||||
        regularization_weights: Union[None, dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss calculation, "
                f"the LPIPS loss will be applied to each frame independently. "
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.learn_logvar = learn_logvar

        self.discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
        ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

    def get_trainable_parameters(self) -> Any:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Any:
        if self.learn_logvar:
            yield self.logvar
        yield from ()

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(
                nll_loss, self.last_layer[0], retain_graph=True
            )[0]
            g_grads = torch.autograd.grad(
                g_loss, self.last_layer[0], retain_graph=True
            )[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        regularization_log,
        inputs,
        reconstructions,
        optimizer_idx,
        global_step,
        last_layer=None,
        split="train",
        weights=None,
    ):
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            logits_fake = self.discriminator(reconstructions.contiguous())
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = weighted_nll_loss + d_weight * disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                log[f"{split}/{k}"] = regularization_log[k].detach().mean()

            log.update(
                {
                    "{}/total_loss".format(split): loss.clone().detach().mean(),
                    "{}/logvar".format(split): self.logvar.detach(),
                    "{}/nll_loss".format(split): nll_loss.detach().mean(),
                    "{}/rec_loss".format(split): rec_loss.detach().mean(),
                    "{}/d_weight".format(split): d_weight.detach(),
                    "{}/disc_factor".format(split): torch.tensor(disc_factor),
                    "{}/g_loss".format(split): g_loss.detach().mean(),
                }
            )

            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
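For orientation, here is a minimal sketch of how a caller might drive the two-pass loss above (generator pass with optimizer_idx=0, discriminator pass with optimizer_idx=1). The surrounding names (`autoencoder`, `loss_fn`, `opt_ae`, `opt_disc`, `batch`) and the `autoencoder(batch)` / `get_last_layer()` interface are assumptions for illustration only, not part of this diff:

# Illustrative sketch only; names and interfaces here are assumed, not taken from this diff.
def training_step(autoencoder, loss_fn, opt_ae, opt_disc, batch, global_step):
    reconstructions, regularization_log = autoencoder(batch)  # assumed interface
    last_layer = autoencoder.get_last_layer()  # assumed helper returning the final decoder weight

    # Pass 0: autoencoder update (NLL + perceptual + adaptively weighted GAN term).
    ae_loss, ae_log = loss_fn(
        regularization_log,
        batch,
        reconstructions,
        optimizer_idx=0,
        global_step=global_step,
        last_layer=last_layer,
    )
    opt_ae.zero_grad()
    ae_loss.backward()
    opt_ae.step()

    # Pass 1: discriminator update on detached reconstructions.
    d_loss, d_log = loss_fn(
        regularization_log,
        batch,
        reconstructions.detach(),
        optimizer_idx=1,
        global_step=global_step,
        last_layer=last_layer,
    )
    opt_disc.zero_grad()
    d_loss.backward()
    opt_disc.step()

    return {**ae_log, **d_log}

The adaptive weight balances the two objectives as d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4) at the last decoder layer, clamped to [0, 1e4] and scaled by disc_weight, which is why `last_layer` must be reachable from both losses.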
306
sgm/modules/autoencoding/losses/discriminator_loss.py
Normal file
@@ -0,0 +1,306 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


class GeneralLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                f"calculation, the LPIPS loss will be applied to each frame "
                f"independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(
            torch.full((), logvar_init), requires_grad=learn_logvar
        )
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(discriminator_config).apply(
            weights_init
        )
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

        self.forward_keys = [
            "optimizer_idx",
            "global_step",
            "last_layer",
            "split",
            "regularization_log",
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        yield from ()

    @torch.no_grad()
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # Non patch-discriminator
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h, w)

        # parameters for colormapping
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps["PiYG"]  # diverging colormap

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
            # -> (b, 1, ..., 3)
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, "b 1 ... c -> b c ...")

        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode="nearest",
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode="nearest",
            antialias=False,
        )

        # alpha value of logits for overlay
        alpha_real = torch.abs(logits_real) / high
        alpha_fake = torch.abs(logits_fake) / high
        # -> (b, 1, h, w) in range [0, 0.5]
        # alpha value of lines don't really matter, since the values are the same
        # for both images and logits anyway
        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
        # -> (1, h, w)
        # blend logits and images together

        # prepare logits for plotting
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # resize logits
        # -> (b, 3, h, w)

        # make some grids
        # add all logits to one plot
        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        # I just love how torchvision calls the number of columns `nrow`
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, h, w)

        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(
            0.5 * reconstructions + 0.5, nrow=4
        )
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
        # -> (3, h, w) in range [0, 1]

        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

        # Create labeled colorbar
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # manually convert figure to numpy
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)

        # Add colorbar to plot
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            "vis_logits": 2 * annotated_grid[None, ...] - 1,
            "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
        }

    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
    ) -> torch.Tensor:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,  # added because I changed the order here
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = "train",
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                if k in self.additional_log_keys:
                    log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()

            log.update(
                {
                    f"{split}/loss/total": loss.clone().detach().mean(),
                    f"{split}/loss/nll": nll_loss.detach().mean(),
                    f"{split}/loss/rec": rec_loss.detach().mean(),
                    f"{split}/loss/g": g_loss.detach().mean(),
                    f"{split}/scalars/logvar": self.logvar.detach(),
                    f"{split}/scalars/d_weight": d_weight.detach(),
                }
            )

            return loss, log
        elif optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f"{split}/loss/disc": d_loss.clone().detach().mean(),
                f"{split}/logits/real": logits_real.detach().mean(),
                f"{split}/logits/fake": logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")

    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        return nll_loss, weighted_nll_loss
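As a closing usage note (not part of the diff): the new GeneralLPIPSWithDiscriminator is meant to be built from a config and called with inputs/reconstructions positional and everything else keyword-only (the bare `*` in forward). Below is a minimal sketch, assuming the package is importable as `sgm`; the config values, the tiny stand-in decoder, and all variable names are illustrative assumptions only:

# Illustrative sketch only; config values and surrounding names are assumptions.
import torch
import torch.nn as nn

from sgm.util import instantiate_from_config

loss_config = {
    "target": "sgm.modules.autoencoding.losses.discriminator_loss.GeneralLPIPSWithDiscriminator",
    "params": {
        "disc_start": 20001,  # illustrative value
        "learn_logvar": True,
        "regularization_weights": {"kl_loss": 1.0e-6},  # example regularizer key
        "additional_log_keys": ["kl_loss"],
    },
}
loss_fn = instantiate_from_config(loss_config)

# Tiny stand-in decoder so that `last_layer` participates in the graph.
decoder = nn.Conv2d(3, 3, kernel_size=1)
inputs = torch.randn(2, 3, 64, 64)
reconstructions = torch.tanh(decoder(inputs))

# inputs/reconstructions are positional; everything else must be passed by keyword.
loss, log = loss_fn(
    inputs,
    reconstructions,
    regularization_log={"kl_loss": torch.tensor(0.1)},
    optimizer_idx=0,
    global_step=0,  # before disc_start, so only the NLL/perceptual terms are active
    last_layer=decoder.weight,
    split="train",
)
print(log["train/loss/total"], log["train/kl_loss"])

Compared to the class above, the discriminator schedule is handled by the `global_step >= self.discriminator_iter_start` check instead of `adopt_weight`, and `get_nll_loss` factors out the log-variance-weighted reconstruction term `rec_loss / exp(logvar) + logvar`.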