Pre release changes for production (#59)

* clean requirements

* rm taming deps

* isort, black

* mv lpips, license

* clean vq, fix path

* fix loss path, gitignore

* tested requirements pt13

* fix numpy req for python3.8, add tests

* fix name

* fix dep scipy 3.8 pt2

* add black test formatter
Authored by Benjamin Aubin on 2023-07-26 12:09:28 +02:00, committed by GitHub
parent 4a3f0f546e, commit e596332148
31 changed files with 642 additions and 128 deletions

.github/workflows/black.yml (vendored, new file): 15 changed lines

@@ -0,0 +1,15 @@
name: Run black
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install venv
run: |
sudo apt-get -y install python3.10-venv
- uses: psf/black@stable
with:
options: "--check --verbose -l88"
src: "./sgm ./scripts ./main.py"

.github/workflows/test-build.yaml (vendored, new file): 26 changed lines

@@ -0,0 +1,26 @@
name: Build package
on:
push:
pull_request:
jobs:
build:
name: Build
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]
requirements-file: ["pt2", "pt13"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/${{ matrix.requirements-file }}.txt
pip install .

.gitignore (vendored): 9 changed lines

@@ -1,9 +1,14 @@
# extensions
*.egg-info
*.py[cod]
# envs
.pt13
.pt2
.pt2_2
# directories
/checkpoints
/dist
/outputs
build
/build
/src


@@ -59,10 +59,9 @@ This is assuming you have navigated to the `generative-models` root after cloning
```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```
**PyTorch 2.0**
@@ -72,8 +71,20 @@ pip3 install -r requirements_pt13.txt
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
pip3 install -r requirements/pt2.txt
```
#### 3. Install `sgm`
```shell
pip3 install .
```
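A quick, hypothetical way to confirm step 3 worked is to import the package and read the version string that this commit bumps in `sgm/__init__.py`:

```python
# Hypothetical post-install sanity check; not part of this commit's diff.
import sgm

print(sgm.__version__)  # "0.1.0" after the version bump in this release
```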
#### 4. Install `sdata` for training
```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```
## Packaging

main.py: 11 changed lines

@@ -12,22 +12,18 @@ import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only
from sgm.util import (
exists,
instantiate_from_config,
isheatmap,
)
from sgm.util import exists, instantiate_from_config, isheatmap
MULTINODE_HACKS = True
@@ -910,11 +906,12 @@ if __name__ == "__main__":
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import datetime
import os
import socket
import requests
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")

requirements/pt13.txt (new file): 40 changed lines

@@ -0,0 +1,40 @@
black==23.7.0
chardet>=5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
numpy>=1.24.4
omegaconf>=2.3.0
onnx<=1.12.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.8.5
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=1.25.0
tensorboardx==2.5.1
timm>=0.9.2
tokenizers==0.12.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
torchaudio==0.13.1
torchdata==0.5.1
torchmetrics>=1.0.1
torchvision==0.14.1+cu117
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0.post1
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.16

requirements/pt2.txt (new file): 39 changed lines

@@ -0,0 +1,39 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20


@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
xformers==0.0.16
torchaudio==0.13.1
torchvision==0.14.1+cu117
torchmetrics
opencv-python==4.6.0.66
fairscale
pytorch-lightning==1.8.5
fsspec
kornia==0.6.9
matplotlib
natsort
tensorboardx==2.5.1
open-clip-torch
chardet
scipy
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
streamlit>=0.73.1
timm
tokenizers==0.12.1
torchdata==0.5.1
transformers==4.19.1
onnx<=1.12.0
triton
wandb
invisible-watermark
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .


@@ -1,41 +0,0 @@
omegaconf
einops
fire
tqdm
pillow
numpy
webdataset>=0.2.33
ninja
torch
matplotlib
torchaudio>=2.0.2
torchmetrics
torchvision>=0.15.2
opencv-python==4.6.0.66
fairscale
pytorch-lightning==2.0.1
fire
fsspec
kornia==0.6.9
natsort
open-clip-torch
chardet==5.1.0
tensorboardx==2.6
pandas
pudb
pyyaml
urllib3<1.27,>=1.25.4
scipy
streamlit>=0.73.1
timm
tokenizers==0.12.1
transformers==4.19.1
triton==2.0.0
torchdata==0.6.1
wandb
invisible-watermark
xformers
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-e .


@@ -1,4 +1,5 @@
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering


@@ -1,29 +1,28 @@
import os
from typing import Union, List
import math
import os
from typing import List, Union
import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
from sgm.util import append_dims, instantiate_from_config
class WatermarkEmbedder:


@@ -1,9 +1,10 @@
import os
import torch
import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import clip
RESOURCES_ROOT = "scripts/util/detection/"


@@ -1,5 +1,4 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config, get_configs_path
from .util import get_configs_path, instantiate_from_config
__version__ = "0.0.1"
__version__ = "0.1.0"


@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class CIFAR10DataDictWrapper(Dataset):


@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class MNISTDataDictWrapper(Dataset):


@@ -3,11 +3,11 @@ from typing import Any, Union
import torch
import torch.nn as nn
from einops import rearrange
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import NLayerDiscriminator, weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
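With the `taming` dependency removed, the LPIPS loss, PatchGAN discriminator, and GAN loss helpers referenced above are vendored inside `sgm`; written as absolute imports (paths inferred from the relative imports in this hunk, so treat them as assumptions), the new locations would be:

```python
# Inferred absolute equivalents of the relative imports above (assumption, not verbatim diff content).
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
```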


@@ -0,0 +1 @@
vgg.pth


@@ -0,0 +1,23 @@
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -0,0 +1,147 @@
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
from ..util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512]  # vgg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
self.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
outs1[kk]
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer(
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
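
On its own, the vendored module above works as a perceptual distance; a minimal usage sketch (assuming the pretrained `vgg.pth` checkpoint can be downloaded on first use and that the script runs from the repository root, since the checkpoint path is relative):

```python
# Minimal LPIPS usage sketch; first call downloads vgg.pth into the relative checkpoint path.
import torch

from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                    # parameters are frozen in __init__
x = torch.rand(1, 3, 256, 256) * 2 - 1    # inputs roughly in [-1, 1]
y = torch.rand(1, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)                       # shape (1, 1, 1, 1); smaller means more similar
print(d.item())
```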


@@ -0,0 +1,58 @@
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -0,0 +1,88 @@
import functools
import torch.nn as nn
from ..util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if (
type(norm_layer) == functools.partial
): # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
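
To make the "one logit per patch" output concrete, a small illustrative sketch of the discriminator above on a 256x256 RGB input (module path inferred from the relative imports in the loss file, so treat it as an assumption):

```python
# Illustrative sketch of the vendored 3-layer PatchGAN discriminator.
import torch

from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
disc.apply(weights_init)                      # normal init for conv / batchnorm weights
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)                           # torch.Size([1, 1, 30, 30]): one logit per patch
```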


@@ -0,0 +1,128 @@
import hashlib
import os
import requests
import torch
import torch.nn as nn
from tqdm import tqdm
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class ActNorm(nn.Module):
def __init__(
self, num_features, logdet=False, affine=True, allow_reverse_init=False
):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:, :, None, None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height * width * torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:, :, None, None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
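
A short sketch of how `ActNorm` behaves: the first forward pass in training mode initializes the per-channel shift/scale from the batch statistics, and `reverse=True` inverts the transform (module path inferred from the `from ..util import ActNorm` import in the discriminator file above):

```python
# Sketch of ActNorm's data-dependent initialization and its inverse.
import torch

from sgm.modules.autoencoding.lpips.util import ActNorm

norm = ActNorm(num_features=64, logdet=True)
norm.train()                                  # training mode so the first forward initializes loc/scale
x = torch.randn(8, 64, 16, 16)
h, logdet = norm(x)                           # logdet has one entry per batch element
x_rec = norm(h, reverse=True)                 # recovers x up to numerical precision
print(h.shape, logdet.shape, torch.allclose(x, x_rec, atol=1e-4))
```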


@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real))
+ torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
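
A tiny numeric check of the two objectives above, with hypothetical, perfectly separated logits:

```python
# Hypothetical smoke test of the vendored GAN losses (import path inferred from the diff above).
import torch

from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

logits_real = torch.full((4, 1, 30, 30), 2.0)    # discriminator is confident these are real
logits_fake = torch.full((4, 1, 30, 30), -2.0)   # ...and confident these are fake
print(hinge_d_loss(logits_real, logits_fake))    # tensor(0.) once both hinge margins are satisfied
print(vanilla_d_loss(logits_real, logits_fake))  # softplus(-2) ~= 0.13
```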


@@ -1,7 +1,7 @@
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Model, Encoder, Decoder
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper


@@ -1,10 +1,11 @@
import torch
import numpy as np
from functools import partial
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
from ...util import append_zero
from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero
def generate_roughly_equally_spaced_steps(


@@ -3,9 +3,9 @@ from typing import List, Optional, Union
import torch
import torch.nn as nn
from omegaconf import ListConfig
from taming.modules.losses.lpips import LPIPS
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
class StandardDiffusionLoss(nn.Module):


@@ -30,5 +30,5 @@ class OpenAIWrapper(IdentityWrapper):
timesteps=t,
context=c.get("crossattn", None),
y=c.get("vector", None),
**kwargs
**kwargs,
)


@@ -1,5 +1,5 @@
import torch
import numpy as np
import torch
class AbstractDistribution: