mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-16 03:04:31 +01:00
soon is now
This commit is contained in:
7
sgm/modules/diffusionmodules/__init__.py
Normal file
7
sgm/modules/diffusionmodules/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .denoiser import Denoiser
|
||||
from .discretizer import Discretization
|
||||
from .loss import StandardDiffusionLoss
|
||||
from .model import Model, Encoder, Decoder
|
||||
from .openaimodel import UNetModel
|
||||
from .sampling import BaseDiffusionSampler
|
||||
from .wrappers import OpenAIWrapper
|
||||
63
sgm/modules/diffusionmodules/denoiser.py
Normal file
63
sgm/modules/diffusionmodules/denoiser.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
|
||||
|
||||
class Denoiser(nn.Module):
    """Preconditioning wrapper around a diffusion network.

    Evaluates the wrapped network with the (c_skip, c_out, c_in, c_noise)
    scalings produced by ``scaling_config`` and exposes the per-sigma loss
    weighting via :meth:`w`.

    Args:
        weighting_config: config for the loss-weighting callable.
        scaling_config: config for the callable mapping sigma to
            (c_skip, c_out, c_in, c_noise).
    """

    def __init__(self, weighting_config, scaling_config):
        super().__init__()

        self.weighting = instantiate_from_config(weighting_config)
        self.scaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma):
        # Identity here; DiscreteDenoiser overrides this to snap sigma onto
        # a fixed schedule.
        return sigma

    def possibly_quantize_c_noise(self, c_noise):
        # Identity here; DiscreteDenoiser overrides this to map c_noise to
        # schedule indices.
        return c_noise

    def w(self, sigma):
        """Loss weight for the given noise level(s)."""
        return self.weighting(sigma)

    def __call__(self, network, input, sigma, cond, **additional_model_inputs):
        """Denoise ``input`` at noise level ``sigma`` with ``network``.

        ``**additional_model_inputs`` is new but backward-compatible:
        StandardDiffusionLoss calls the denoiser with extra keyword
        arguments (its ``batch2model_keys`` entries), which the previous
        signature rejected with a TypeError. They are forwarded verbatim
        to the network.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        # Broadcast sigma over the non-batch dims of the input.
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
|
||||
|
||||
|
||||
class DiscreteDenoiser(Denoiser):
    """Denoiser that snaps continuous sigma values onto a fixed, discrete
    noise schedule obtained from ``discretization_config``.

    Args:
        weighting_config: loss-weighting config (see Denoiser).
        scaling_config: input/output scaling config (see Denoiser).
        num_idx: number of discrete sigma levels in the schedule.
        discretization_config: config whose instantiation, when called,
            yields the sigma schedule.
        do_append_zero: forwarded to the discretization (append a final 0).
        quantize_c_noise: if True, condition the network on the integer
            schedule index instead of the continuous c_noise value.
        flip: forwarded to the discretization (reverse sigma ordering).
    """

    def __init__(
        self,
        weighting_config,
        scaling_config,
        num_idx,
        discretization_config,
        do_append_zero=False,
        quantize_c_noise=True,
        flip=True,
    ):
        super().__init__(weighting_config, scaling_config)
        sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        # Buffer, not a parameter: follows the module's device/dtype moves
        # but is never trained.
        self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

    def sigma_to_idx(self, sigma):
        """Index of the schedule entry nearest to each given sigma."""
        # Broadcast subtraction over the schedule axis, then take the
        # argmin of absolute distances along it.
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx):
        """Sigma value(s) at the given schedule index/indices."""
        return self.sigmas[idx]

    def possibly_quantize_sigma(self, sigma):
        # Replace each sigma by the closest value on the discrete schedule.
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    def possibly_quantize_c_noise(self, c_noise):
        if self.quantize_c_noise:
            # Condition the network on the discrete timestep index.
            return self.sigma_to_idx(c_noise)
        else:
            return c_noise
|
||||
31
sgm/modules/diffusionmodules/denoiser_scaling.py
Normal file
31
sgm/modules/diffusionmodules/denoiser_scaling.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
|
||||
class EDMScaling:
    """EDM (Karras et al.) preconditioning coefficients.

    Maps a noise level sigma to the (c_skip, c_out, c_in, c_noise)
    scalings of the EDM parameterization.
    """

    def __init__(self, sigma_data=0.5):
        # Assumed standard deviation of the clean data distribution.
        self.sigma_data = sigma_data

    def __call__(self, sigma):
        sd_sq = self.sigma_data**2
        total_var = sigma**2 + sd_sq
        c_skip = sd_sq / total_var
        c_out = sigma * self.sigma_data / total_var**0.5
        c_in = 1 / total_var**0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class EpsScaling:
    """Scalings for the epsilon (noise-prediction) parameterization."""

    def __call__(self, sigma):
        ones = torch.ones_like(sigma, device=sigma.device)
        inv_norm = 1 / (sigma**2 + 1.0) ** 0.5
        # c_skip, c_out, c_in, c_noise
        return ones, -sigma, inv_norm, sigma.clone()
|
||||
|
||||
|
||||
class VScaling:
    """Scalings for the v-prediction parameterization."""

    def __call__(self, sigma):
        total_var = sigma**2 + 1.0
        norm = total_var**0.5
        c_skip = 1.0 / total_var
        c_out = -sigma / norm
        c_in = 1.0 / norm
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
|
||||
24
sgm/modules/diffusionmodules/denoiser_weighting.py
Normal file
24
sgm/modules/diffusionmodules/denoiser_weighting.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
class UnitWeighting:
    """Constant loss weight of 1 for every noise level."""

    def __call__(self, sigma):
        weights = torch.ones_like(sigma, device=sigma.device)
        return weights
|
||||
|
||||
|
||||
class EDMWeighting:
    """EDM loss weighting: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""

    def __init__(self, sigma_data=0.5):
        # Assumed standard deviation of the clean data distribution.
        self.sigma_data = sigma_data

    def __call__(self, sigma):
        numerator = sigma**2 + self.sigma_data**2
        denominator = (sigma * self.sigma_data) ** 2
        return numerator / denominator
|
||||
|
||||
|
||||
class VWeighting(EDMWeighting):
    """EDM weighting specialized to sigma_data = 1 (v-prediction)."""

    def __init__(self):
        # Fix the data standard deviation at 1.0.
        super().__init__(sigma_data=1.0)
|
||||
|
||||
|
||||
class EpsWeighting:
    """Loss weighting 1 / sigma^2 for the eps parameterization."""

    def __call__(self, sigma):
        weights = sigma**-2.0
        return weights
|
||||
65
sgm/modules/diffusionmodules/discretizer.py
Normal file
65
sgm/modules/diffusionmodules/discretizer.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ...util import append_zero
|
||||
from ...modules.diffusionmodules.util import make_beta_schedule
|
||||
|
||||
|
||||
class Discretization:
    """Base class for noise-level (sigma) schedules.

    Subclasses implement ``get_sigmas``; calling the instance optionally
    appends a trailing zero and/or reverses the ordering.
    """

    def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
        sigmas = self.get_sigmas(n, device)
        if do_append_zero:
            sigmas = append_zero(sigmas)
        if flip:
            sigmas = torch.flip(sigmas, (0,))
        return sigmas
|
||||
|
||||
|
||||
class EDMDiscretization(Discretization):
    """Karras et al. (EDM) rho-spaced sigma schedule from sigma_max down
    to sigma_min."""

    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device):
        # Interpolate linearly in sigma^(1/rho) space, then map back.
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        ramp = torch.linspace(0, 1, n, device=device)
        interpolated = max_inv_rho + ramp * (min_inv_rho - max_inv_rho)
        return interpolated**self.rho
|
||||
|
||||
|
||||
class LegacyDDPMDiscretization(Discretization):
    """Sigma schedule derived from the original LDM/DDPM linear beta
    schedule, reproducing the legacy LDM timestep subsampling.

    Args:
        linear_start: first beta of the linear schedule.
        linear_end: last beta of the linear schedule.
        num_timesteps: number of diffusion steps the betas are defined over.
        legacy_range: if True, shift subsampled indices up by one
            (the "Legacy LDM Hack" below); otherwise use the alternative
            index selection.
    """

    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        legacy_range=True,
    ):
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        # Cumulative product of alphas ("alpha bar" in DDPM notation).
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)
        self.legacy_range = legacy_range

    def get_sigmas(self, n, device):
        if n < self.num_timesteps:
            # Subsample n timesteps out of the full schedule with stride c.
            c = self.num_timesteps // n

            if self.legacy_range:
                timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
                timesteps += 1  # Legacy LDM Hack
            else:
                timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
                timesteps -= 1
                timesteps = timesteps[1:]

            alphas_cumprod = self.alphas_cumprod[timesteps]
        else:
            alphas_cumprod = self.alphas_cumprod

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        # sigma = sqrt((1 - abar) / abar): the noise level of each timestep.
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        # Return sigmas in decreasing-timestep order.
        return torch.flip(sigmas, (0,))
|
||||
53
sgm/modules/diffusionmodules/guiders.py
Normal file
53
sgm/modules/diffusionmodules/guiders.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ...util import default, instantiate_from_config
|
||||
|
||||
|
||||
class VanillaCFG:
    """
    implements parallelized CFG

    Classifier-free guidance where the unconditional and conditional
    branches run in a single batched forward pass: `prepare_inputs`
    doubles the batch (unconditional first, conditional second) and
    `__call__` recombines the two halves via the dynamic-thresholding
    module.
    """

    def __init__(self, scale, dyn_thresh_config=None):
        # Guidance scale as a (constant) function of sigma.
        scale_schedule = lambda scale, sigma: scale  # independent of step
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    def __call__(self, x, sigma):
        # First half of the batch is the unconditional branch, second half
        # the conditional one (matching the order used in prepare_inputs).
        x_u, x_c = x.chunk(2)
        scale_value = self.scale_schedule(sigma)
        x_pred = self.dyn_thresh(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(self, x, s, c, uc):
        c_out = dict()

        for k in c:
            if k in ["vector", "crossattn", "concat"]:
                # Batch unconditional and conditional embeddings together.
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                # Non-embedding entries must agree between cond and uncond.
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
|
||||
class IdentityGuider:
    """No-op guider: passes model output and conditioning through unchanged."""

    def __call__(self, x, sigma):
        # Single (conditional) branch only; nothing to combine.
        return x

    def prepare_inputs(self, x, s, c, uc):
        # Shallow-copy the conditioning dict; `uc` is ignored.
        conditioning = {key: value for key, value in c.items()}

        return x, s, conditioning
|
||||
69
sgm/modules/diffusionmodules/loss.py
Normal file
69
sgm/modules/diffusionmodules/loss.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import ListConfig
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
|
||||
|
||||
class StandardDiffusionLoss(nn.Module):
    """Denoising-diffusion training loss.

    Per batch: samples one sigma per example, corrupts the input with
    (optionally offset-augmented) Gaussian noise at that level, denoises
    it, and reduces the reconstruction error with an l2, l1 or LPIPS
    distance weighted per-sigma by ``denoiser.w``.

    Args:
        sigma_sampler_config: config for the sigma (noise level) sampler.
        type: reconstruction distance, one of "l2", "l1", "lpips".
        offset_noise_level: scale of extra per-sample constant noise;
            0.0 disables offset noise.
        batch2model_keys: batch entries forwarded to the model as extra
            keyword arguments.
    """

    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        assert type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        if type == "lpips":
            # Perceptual metric; eval() freezes its batch-norm statistics.
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch):
        """Return the per-example training loss (shape: (batch,))."""
        cond = conditioner(batch)
        # Extra batch entries the model should receive as kwargs.
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # Offset noise: add a per-sample constant shift (broadcast over
            # all non-batch dims) on top of the white noise.
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            )
        noised_input = input + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        w = append_dims(denoiser.w(sigmas), input.ndim)
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        """Reduce the weighted error to one scalar per example."""
        if self.type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.type == "lpips":
            # NOTE(review): the LPIPS branch ignores the weighting `w` —
            # confirm whether that is intended.
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
|
||||
743
sgm/modules/diffusionmodules/model.py
Normal file
743
sgm/modules/diffusionmodules/model.py
Normal file
@@ -0,0 +1,743 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
|
||||
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to 1/10000.
    log_scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_scale)
    freqs = freqs.to(device=timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    # Sines in the first half of the embedding, cosines in the second.
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the eps/affine settings used throughout this model."""
    return torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True
    )
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest"
        )
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """2x downsampling via a strided 3x3 conv (with manual asymmetric
    padding) or, without conv, 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad right/bottom by one so the strided conv halves the resolution.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep conditioning.

    Keyword-only args:
        in_channels: channels of the input feature map.
        out_channels: output channels; defaults to in_channels.
        conv_shortcut: use a 3x3 conv (vs. 1x1) on the skip path when the
            channel count changes.
        dropout: dropout rate before the second conv.
        temb_channels: size of the timestep embedding; 0 disables the
            conditioning projection.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            # Projects the timestep embedding onto this block's channels.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # Match channel counts on the skip path.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # Add the projected timestep embedding per-channel.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        skip = x
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                skip = self.conv_shortcut(skip)
            else:
                skip = self.nin_shortcut(skip)

        return skip + h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
    """Single-head linear attention, constructor-compatible with AttnBlock."""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, dim_head=in_channels, heads=1)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Single-head self-attention over a 2D feature map, with residual
    connection; uses torch's scaled_dot_product_attention."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            # 1x1 conv = per-pixel linear projection.
            return torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        self.norm = Normalize(in_channels)
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        normed = self.norm(h_)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        b, c, h, w = q.shape
        # Flatten spatial dims into a sequence with an explicit head axis of 1.
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        # Default softmax scale is dim ** -0.5.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(out, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-pixel linear projections for q/k/v and output.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # Optional xformers attention operator override; None = auto-select.
        self.attention_op: Optional[Any] = None

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        # Flatten spatial dims into a sequence axis.
        q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))

        # Insert and fold a head axis of size 1 (single-head attention) to
        # match xformers' (B*heads, seq, dim) layout.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        # Undo the head folding back to (B, seq, C).
        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        # Residual connection.
        return x + h_
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    """Adapts MemoryEfficientCrossAttention (which consumes (b, seq, c)
    sequences) to 2D feature maps, adding a residual connection."""

    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        # Keep the original (b, c, h, w) tensor for the residual. The
        # previous code rebound `x` to the rearranged (b, h*w, c) sequence
        # and then added it to the (b, c, h, w) attention output — a shape
        # mismatch whenever this path is exercised.
        seq = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(seq, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """Build an attention block of the requested flavour.

    Args:
        in_channels: channel count of the feature map the block operates on.
        attn_type: one of "vanilla", "vanilla-xformers",
            "memory-efficient-cross-attn", "linear", "none".
        attn_kwargs: extra constructor kwargs, only used by
            "memory-efficient-cross-attn".

    Returns:
        An nn.Module implementing the chosen attention
        (nn.Identity for "none").
    """
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    if (
        version.parse(torch.__version__) < version.parse("2.0.0")
        and attn_type != "none"
    ):
        # Pre-2.0 torch lacks scaled_dot_product_attention; require xformers.
        assert XFORMERS_IS_AVAILABLE, (
            f"We do not support vanilla attention in {torch.__version__} anymore, "
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        attn_type = "vanilla-xformers"
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        assert attn_kwargs is None
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # Fixed: this previously compared the builtin `type` to the string,
        # which is always False, so cross-attn requests silently fell
        # through to LinAttnBlock.
        attn_kwargs["query_dim"] = in_channels
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == "none":
        # nn.Identity ignores its constructor args.
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)
|
||||
|
||||
|
||||
class Model(nn.Module):
    """UNet-style diffusion backbone (pytorch_diffusion-derived):
    downsampling path, attention-equipped middle, upsampling path with
    skip connections, optionally conditioned on a timestep embedding.

    Keyword-only args:
        ch: base channel count.
        out_ch: number of output channels.
        ch_mult: per-level channel multipliers; also fixes the number of
            resolution levels.
        num_res_blocks: residual blocks per resolution level.
        attn_resolutions: spatial resolutions at which attention is added.
        dropout: dropout rate inside the residual blocks.
        resamp_with_conv: use convs when up/downsampling.
        in_channels: channels of the input (plus context, if concatenated).
        resolution: input spatial resolution.
        use_timestep: condition on a sinusoidal timestep embedding.
        use_linear_attn: force attn_type to "linear".
        attn_type: attention flavour passed to make_attn.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two-layer MLP applied to the sinusoidal
            # embedding in forward().
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve the spatial resolution between levels (not after
                # the last one).
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    # Last block of the level consumes the skip from the
                    # level below (different channel count).
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        """Run the UNet; ``t`` is required when use_timestep is True."""
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: keep every activation in `hs` for the skip
        # connections consumed on the way back up.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: pop skips in reverse order and concat on channels.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Weight of the final conv (e.g. for adaptive loss balancing)."""
        return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Downsampling half of the autoencoder: conv stem, resnet/attention
    levels with progressive downsampling, attention-equipped middle, and a
    final conv producing ``2 * z_channels`` channels when ``double_z``
    (mean and logvar) or ``z_channels`` otherwise.

    Note: ``out_ch`` is accepted for signature symmetry with Decoder but
    is not used here; unknown kwargs are swallowed by ``**ignore_kwargs``.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        # No timestep conditioning in the autoencoder.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve the resolution between levels (not after the last).
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode an image batch to the latent moment representation."""
        # timestep embedding (unused in the autoencoder path)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
make_resblock_cls = self._make_resblock()
|
||||
make_conv_cls = self._make_conv()
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = make_conv_cls(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
    def _make_attn(self) -> Callable:
        """Return the factory used to build attention layers.

        Hook point: subclasses can override to substitute a different
        attention implementation without touching the constructor.
        """
        return make_attn
|
||||
|
||||
    def _make_resblock(self) -> Callable:
        """Return the factory used to build residual blocks; overridable hook."""
        return ResnetBlock
|
||||
|
||||
    def _make_conv(self) -> Callable:
        """Return the factory used to build the output convolution; overridable hook."""
        return torch.nn.Conv2d
|
||||
|
||||
    def get_last_layer(self, **kwargs):
        """Return the weight tensor of the final output convolution.

        NOTE(review): callers presumably use this for loss balancing /
        gradient inspection — confirm against call sites.
        """
        return self.conv_out.weight
|
||||
|
||||
    def forward(self, z, **kwargs):
        """Decode a latent ``z`` back to image space.

        :param z: latent tensor; spatial/channel dims are expected to match
            ``self.z_shape[1:]`` (the original shape assert is kept disabled).
        :return: decoded tensor, or pre-head features when ``give_pre_end``.
        """
        # assert z.shape[1:] == self.z_shape[1:]
        # Remember the last input shape for debugging/inspection.
        self.last_z_shape = z.shape

        # timestep embedding: the decoder is unconditional on time, so the
        # residual blocks receive temb=None.
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling: iterate resolutions coarse-to-fine; `self.up` was filled
        # with insert(0, ...) so index order matches reversed iteration here.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            # Skip the norm/nonlinearity/conv head and return raw features.
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            # Squash output into [-1, 1].
            h = torch.tanh(h)
        return h
|
||||
1262
sgm/modules/diffusionmodules/openaimodel.py
Normal file
1262
sgm/modules/diffusionmodules/openaimodel.py
Normal file
File diff suppressed because it is too large
Load Diff
365
sgm/modules/diffusionmodules/sampling.py
Normal file
365
sgm/modules/diffusionmodules/sampling.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
"""
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...modules.diffusionmodules.sampling_utils import (
|
||||
get_ancestral_step,
|
||||
linear_multistep_coeff,
|
||||
to_d,
|
||||
to_neg_log_sigma,
|
||||
to_sigma,
|
||||
)
|
||||
from ...util import append_dims, default, instantiate_from_config
|
||||
|
||||
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
|
||||
|
||||
class BaseDiffusionSampler:
    """Shared scaffolding for diffusion samplers.

    Owns the sigma discretization, the (classifier-free) guider, and the
    common denoising call used by all concrete samplers.
    """

    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        # Fall back to the identity guider (no guidance) when none is given.
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Compute the sigma schedule and scale the initial noise.

        :return: tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where
            ``s_in`` is a per-batch vector of ones used to broadcast scalar
            sigmas across the batch.
        """
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        # With no unconditional conditioning, guidance degenerates to a no-op.
        uc = default(uc, cond)

        # Scale unit noise to the first noise level.
        # NOTE(review): in-place `*=` mutates the caller's tensor — assumes
        # callers always pass freshly drawn noise here; confirm.
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)

        s_in = x.new_ones([x.shape[0]])

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, uc):
        """Run the denoiser on guider-prepared inputs and let the guider
        combine the outputs (e.g. classifier-free guidance mixing)."""
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        """Iterator over step indices ``0 .. num_sigmas - 2``, optionally
        wrapped in a tqdm progress bar when ``verbose`` is set."""
        sigma_generator = range(num_sigmas - 1)
        if self.verbose:
            print("#" * 30, " Sampling setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            # NOTE(review): the generator yields num_sigmas - 1 items, while
            # tqdm's total is num_sigmas — off-by-one in the progress display.
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas,
                desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
            )
        return sigma_generator
|
||||
|
||||
|
||||
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    """Base class for samplers that advance ``x`` one sigma step at a time."""

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        """Advance ``x`` from ``sigma`` to ``next_sigma``; subclass hook."""
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        """Explicit Euler update: move along derivative ``d`` for step ``dt``."""
        update = dt * d
        return x + update
|
||||
|
||||
|
||||
class EDMSampler(SingleStepDiffusionSampler):
    """Stochastic single-step sampler with optional noise "churn",
    controlled by ``s_churn`` / ``s_tmin`` / ``s_tmax`` / ``s_noise``."""

    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn  # total churn budget spread over all steps
        self.s_tmin = s_tmin  # churn is active only for sigma in [s_tmin, s_tmax]
        self.s_tmax = s_tmax
        self.s_noise = s_noise  # scale of the injected churn noise

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
        """One step from ``sigma`` to ``next_sigma``.

        ``gamma`` > 0 first raises the noise level (churn) to ``sigma_hat``
        before taking the deterministic step.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            # Add fresh noise so x matches the raised level sigma_hat.
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        euler_step = self.euler_step(x, d, dt)
        # Subclasses may refine the Euler prediction (e.g. Heun correction).
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            # Per-step churn amount, capped at sqrt(2) - 1; zero outside the
            # configured sigma window.
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        return x
|
||||
|
||||
|
||||
class AncestralSampler(SingleStepDiffusionSampler):
    """Base class for ancestral samplers: each step is a deterministic move
    down to ``sigma_down`` plus re-injection of ``sigma_up`` fresh noise."""

    def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.eta = eta  # 0 = fully deterministic, 1 = full ancestral noise
        self.s_noise = s_noise  # scale of the injected noise
        # Pluggable noise source; default is i.i.d. Gaussian like x.
        self.noise_sampler = lambda x: torch.randn_like(x)

    def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
        """Deterministic Euler move from ``sigma`` down to ``sigma_down``."""
        d = to_d(x, sigma, denoised)
        dt = append_dims(sigma_down - sigma, x.ndim)

        return self.euler_step(x, d, dt)

    def ancestral_step(self, x, sigma, next_sigma, sigma_up):
        """Re-inject noise, but only where the target noise level is > 0
        (the final step to sigma=0 stays noise-free)."""
        x = torch.where(
            append_dims(next_sigma, x.ndim) > 0.0,
            x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
            x,
        )
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        for i in self.get_sigma_gen(num_sigmas):
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
            )

        return x
|
||||
|
||||
|
||||
class LinearMultistepSampler(BaseDiffusionSampler):
    """Linear multistep (LMS) sampler: combines up to ``order`` stored
    derivatives using numerically integrated Lagrange-basis coefficients."""

    def __init__(
        self,
        order=4,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.order = order  # maximum number of history derivatives kept

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        ds = []  # derivative history, oldest first
        # Coefficients are integrated on CPU from the numpy sigma schedule.
        sigmas_cpu = sigmas.detach().cpu().numpy()
        for i in self.get_sigma_gen(num_sigmas):
            sigma = s_in * sigmas[i]
            denoised = denoiser(
                *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
            )
            denoised = self.guider(denoised, sigma)
            d = to_d(x, sigma, denoised)
            ds.append(d)
            if len(ds) > self.order:
                # Drop the oldest derivative once the window is full.
                ds.pop(0)
            # Effective order ramps up while fewer history points exist.
            cur_order = min(i + 1, self.order)
            coeffs = [
                linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
                for j in range(cur_order)
            ]
            # coeffs[0] pairs with the newest derivative, hence reversed(ds).
            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

        return x
|
||||
|
||||
|
||||
class EulerEDMSampler(EDMSampler):
    """First-order (plain Euler) variant of the EDM sampler."""

    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        """No higher-order correction: the Euler prediction is the result."""
        return euler_step
|
||||
|
||||
|
||||
class HeunEDMSampler(EDMSampler):
    """EDM sampler with Heun's second-order (trapezoidal) correction."""

    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        """Refine the Euler prediction by averaging the derivative at both
        ends of the step; skipped when the target noise level is zero."""
        if torch.sum(next_sigma) < 1e-14:
            # All target noise levels are 0: skip the extra network call.
            return euler_step

        denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
        d_new = to_d(euler_step, next_sigma, denoised)
        avg_d = (d + d_new) / 2.0

        # Apply the correction only where the noise level is non-zero.
        corrected = x + avg_d * dt
        return torch.where(
            append_dims(next_sigma, x.ndim) > 0.0, corrected, euler_step
        )
|
||||
|
||||
|
||||
class EulerAncestralSampler(AncestralSampler):
    """Euler method combined with ancestral (stochastic) noise re-injection."""

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
        """Deterministic move to sigma_down, then re-inject sigma_up noise."""
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        model_pred = self.denoise(x, denoiser, sigma, cond, uc)
        stepped = self.ancestral_euler_step(x, model_pred, sigma, sigma_down)
        return self.ancestral_step(stepped, sigma, next_sigma, sigma_up)
|
||||
|
||||
|
||||
class DPMPP2SAncestralSampler(AncestralSampler):
    """DPM-Solver++(2S) ancestral sampler: a second-order single-step solver
    in t = -log(sigma) time, with ancestral noise injection."""

    def get_variables(self, sigma, sigma_down):
        """Return step size ``h``, midpoint ``s`` and endpoints ``t``,
        ``t_next`` in negative-log-sigma time."""
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
        h = t_next - t
        s = t + 0.5 * h
        return h, s, t, t_next

    def get_mult(self, h, s, t, t_next):
        """Multipliers for the midpoint update (mult1/2) and the final
        update (mult3/4)."""
        mult1 = to_sigma(s) / to_sigma(t)
        mult2 = (-0.5 * h).expm1()
        mult3 = to_sigma(t_next) / to_sigma(t)
        mult4 = (-h).expm1()

        return mult1, mult2, mult3, mult4

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # First-order (Euler) prediction, used as fallback at sigma_down == 0.
        x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)

        if torch.sum(sigma_down) < 1e-14:
            # Save a network evaluation if all noise levels are 0
            x = x_euler
        else:
            h, s, t, t_next = self.get_variables(sigma, sigma_down)
            mult = [
                append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
            ]

            # Midpoint evaluation followed by the full 2S update.
            x2 = mult[0] * x - mult[1] * denoised
            denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
            x_dpmpp2s = mult[2] * x - mult[3] * denoised2

            # apply correction if noise level is not 0
            x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)

        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
        return x
|
||||
|
||||
|
||||
class DPMPP2MSampler(BaseDiffusionSampler):
    """DPM-Solver++(2M): a second-order multistep solver that reuses the
    previous step's denoised prediction instead of a midpoint evaluation."""

    def get_variables(self, sigma, next_sigma, previous_sigma=None):
        """Return step size ``h``, step-size ratio ``r`` (None on the first
        step) and endpoints in t = -log(sigma) time."""
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
        h = t_next - t

        if previous_sigma is not None:
            h_last = t - to_neg_log_sigma(previous_sigma)
            r = h_last / h
            return h, r, t, t_next
        else:
            return h, None, t, t_next

    def get_mult(self, h, r, t, t_next, previous_sigma):
        """Update multipliers; the multistep pair (mult3/4) only exists once
        a previous step is available."""
        mult1 = to_sigma(t_next) / to_sigma(t)
        mult2 = (-h).expm1()

        if previous_sigma is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            return mult1, mult2

    def sampler_step(
        self,
        old_denoised,
        previous_sigma,
        sigma,
        next_sigma,
        denoiser,
        x,
        cond,
        uc=None,
    ):
        """One 2M step; returns the new ``x`` plus the current ``denoised``
        to be passed back as ``old_denoised`` on the next step."""
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
        ]

        # First-order (DPM-Solver++(1)) update, always available.
        x_standard = mult[0] * x - mult[1] * denoised
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
            # Save a network evaluation if all noise levels are 0 or on the first step
            return x_standard, denoised
        else:
            # Second-order update extrapolating with the previous prediction.
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            x_advanced = mult[0] * x - mult[1] * denoised_d

            # apply correction if noise level is not 0 and not first step
            x = torch.where(
                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
            )

        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        old_denoised = None
        for i in self.get_sigma_gen(num_sigmas):
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * sigmas[i - 1],
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
            )

        return x
|
||||
48
sgm/modules/diffusionmodules/sampling_utils.py
Normal file
48
sgm/modules/diffusionmodules/sampling_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from scipy import integrate
|
||||
|
||||
from ...util import append_dims
|
||||
|
||||
|
||||
class NoDynamicThresholding:
    """Classifier-free guidance combiner without any thresholding:
    a plain linear blend of the conditional and unconditional outputs."""

    def __call__(self, uncond, cond, scale):
        """Return ``uncond + scale * (cond - uncond)``."""
        delta = cond - uncond
        return uncond + scale * delta
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    """Integrate the j-th Lagrange basis polynomial over [t[i], t[i+1]].

    These integrals weight the stored derivatives in a linear multistep
    sampler step.

    :raises ValueError: when fewer than ``order`` history points exist.
    """
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    def basis(tau):
        # Product form of the Lagrange basis polynomial for node j.
        value = 1.0
        for k in range(order):
            if k == j:
                continue
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    return integrate.quad(basis, t[i], t[i + 1], epsrel=epsrel)[0]
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Split the move sigma_from -> sigma_to into a deterministic part
    (``sigma_down``) and a fresh-noise part (``sigma_up``) scaled by ``eta``.

    With ``eta`` falsy the step is fully deterministic.
    """
    if not eta:
        return sigma_to, 0.0
    variance_ratio = (sigma_from**2 - sigma_to**2) / sigma_from**2
    sigma_up = torch.minimum(
        sigma_to,
        eta * (sigma_to**2 * variance_ratio) ** 0.5,
    )
    # Residual deterministic noise level so the total variance matches.
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
    """Convert a denoised prediction into the ODE derivative dx/dsigma."""
    residual = x - denoised
    return residual / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
def to_neg_log_sigma(sigma):
    """Map sigma to the time variable t = -log(sigma)."""
    return -torch.log(sigma)
|
||||
|
||||
|
||||
def to_sigma(neg_log_sigma):
    """Inverse of :func:`to_neg_log_sigma`: sigma = exp(-t)."""
    return torch.exp(-neg_log_sigma)
|
||||
31
sgm/modules/diffusionmodules/sigma_sampling.py
Normal file
31
sgm/modules/diffusionmodules/sigma_sampling.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
from ...util import default, instantiate_from_config
|
||||
|
||||
|
||||
class EDMSampling:
    """Log-normal training-sigma distribution as used in the EDM paper:
    log(sigma) ~ N(p_mean, p_std^2)."""

    def __init__(self, p_mean=-1.2, p_std=1.2):
        # Parameters of the normal distribution over log-sigma.
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        """Draw ``n_samples`` sigmas; ``rand`` overrides the Gaussian draw."""
        gauss = default(rand, torch.randn((n_samples,)))
        log_sigma = self.p_mean + self.p_std * gauss
        return log_sigma.exp()
|
||||
|
||||
|
||||
class DiscreteSampling:
    """Draw training sigmas uniformly from a fixed discretization table."""

    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
        self.num_idx = num_idx
        # Precompute the sigma table once from the configured discretization.
        discretization = instantiate_from_config(discretization_config)
        self.sigmas = discretization(num_idx, do_append_zero=do_append_zero, flip=flip)

    def idx_to_sigma(self, idx):
        """Look up sigma value(s) for integer index/indices ``idx``."""
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None):
        """Sample ``n_samples`` sigmas; ``rand`` overrides the index draw."""
        indices = default(rand, torch.randint(0, self.num_idx, (n_samples,)))
        return self.idx_to_sigma(indices)
|
||||
308
sgm/modules/diffusionmodules/util.py
Normal file
308
sgm/modules/diffusionmodules/util.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
adopted from
|
||||
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
and
|
||||
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
and
|
||||
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
|
||||
thanks!
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
|
||||
|
||||
def make_beta_schedule(
    schedule,
    n_timestep,
    linear_start=1e-4,
    linear_end=2e-2,
):
    """Build a beta schedule for a DDPM-style diffusion process.

    :param schedule: schedule name; only "linear" is implemented here.
    :param n_timestep: number of diffusion timesteps.
    :param linear_start: first beta of the linear schedule.
    :param linear_end: last beta of the linear schedule.
    :return: numpy array of shape (n_timestep,) with the betas.
    :raises NotImplementedError: for any schedule other than "linear".
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, squared back — the standard DDPM
        # "linear" schedule, computed in float64 for precision.
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
        return betas.numpy()
    # Previously an unknown schedule fell through and crashed with an
    # unrelated NameError on `betas`; fail loudly and clearly instead.
    raise NotImplementedError(f"unknown beta schedule: {schedule!r}")
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
    """Gather ``a[t]`` per batch element and reshape so the result
    broadcasts against a tensor of shape ``x_shape``."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape)
|
||||
|
||||
|
||||
def mixed_checkpoint(func, inputs: dict, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass. This
    differs from the original checkpoint function borrowed from
    https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
    in that it also works with non-tensor inputs.

    :param func: the function to evaluate.
    :param inputs: the argument dictionary to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(**inputs)
    # Partition inputs in a single pass (the original iterated the dict four
    # times); only tensors participate in autograd, but both groups are
    # forwarded to `func` on recompute. Dict order is preserved, so the flat
    # argument layout matches what MixedCheckpointFunction expects.
    tensor_keys, tensor_inputs = [], []
    non_tensor_keys, non_tensor_inputs = [], []
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            tensor_keys.append(key)
            tensor_inputs.append(val)
        else:
            non_tensor_keys.append(key)
            non_tensor_inputs.append(val)
    args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
    return MixedCheckpointFunction.apply(
        func,
        len(tensor_inputs),
        len(non_tensor_inputs),
        tensor_keys,
        non_tensor_keys,
        *args,
    )
|
||||
|
||||
|
||||
class MixedCheckpointFunction(torch.autograd.Function):
    """Autograd function backing :func:`mixed_checkpoint`.

    Forward runs under ``no_grad``; backward recomputes the forward pass with
    grad enabled, splitting the flat ``*args`` back into tensor inputs (which
    receive gradients) and non-tensor inputs (which do not).
    """

    @staticmethod
    def forward(
        ctx,
        run_function,
        length_tensors,
        length_non_tensors,
        tensor_keys,
        non_tensor_keys,
        *args,
    ):
        # Slice boundaries of the flat args: [tensors | non-tensors | params].
        ctx.end_tensors = length_tensors
        ctx.end_non_tensors = length_tensors + length_non_tensors
        # Record the autocast state so the recompute in backward matches.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        assert (
            len(tensor_keys) == length_tensors
            and len(non_tensor_keys) == length_non_tensors
        )

        # Rebuild keyword dictionaries from the flat positional layout.
        ctx.input_tensors = {
            key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
        }
        ctx.input_non_tensors = {
            key: val
            for (key, val) in zip(
                non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
            )
        }
        ctx.run_function = run_function
        ctx.input_params = list(args[ctx.end_non_tensors :])

        with torch.no_grad():
            output_tensors = ctx.run_function(
                **ctx.input_tensors, **ctx.input_non_tensors
            )
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
        # Detach and re-enable grad so the recompute builds a fresh graph.
        ctx.input_tensors = {
            key: ctx.input_tensors[key].detach().requires_grad_(True)
            for key in ctx.input_tensors
        }

        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = {
                key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
                for key in ctx.input_tensors
            }
            # shallow_copies.update(additional_args)
            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
        input_grads = torch.autograd.grad(
            output_tensors,
            list(ctx.input_tensors.values()) + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Gradient layout mirrors forward's signature: five None for the
        # non-differentiable leading args, then tensor-input grads, None per
        # non-tensor input, then parameter grads.
        return (
            (None, None, None, None, None)
            + input_grads[: ctx.end_tensors]
            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
            + input_grads[ctx.end_tensors :]
        )
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing
    for reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # Checkpointing disabled: just call through.
        return func(*inputs)
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
    """Autograd function backing :func:`checkpoint`: forward runs under
    ``no_grad``; backward recomputes the forward pass with grad enabled."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # args = inputs (first `length`) followed by extra parameters.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Record the autocast state so the recompute in backward matches.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach and re-enable grad so recomputation builds a fresh graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Two leading None: run_function and length take no gradient.
        return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, simply tile the raw timesteps instead of
                        computing sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero column so the output width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
||||
|
||||
|
||||
def zero_module(module):
    """
    Zero out the parameters of a module in place and return it.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
    """
    Scale every parameter of a module in place by ``scale`` and return it.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
|
||||
|
||||
|
||||
def normalization(channels):
    """
    Make a standard normalization layer: 32-group GroupNorm computed in fp32.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """SiLU / swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        return conv_types[dims](*args, **kwargs)
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}") from None
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
    """
    Create a linear (fully connected) module; thin alias for ``nn.Linear``.
    """
    return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    try:
        return pool_types[dims](*args, **kwargs)
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}") from None
|
||||
34
sgm/modules/diffusionmodules/wrappers.py
Normal file
34
sgm/modules/diffusionmodules/wrappers.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
|
||||
|
||||
|
||||
class IdentityWrapper(nn.Module):
    """Pass-through wrapper around a diffusion model, optionally compiled
    with ``torch.compile`` when available."""

    def __init__(self, diffusion_model, compile_model: bool = False):
        super().__init__()
        # torch.compile only exists on torch >= 2.0; otherwise fall back to
        # an identity transform.
        should_compile = compile_model and version.parse(
            torch.__version__
        ) >= version.parse("2.0.0")
        compile = torch.compile if should_compile else (lambda x: x)
        self.diffusion_model = compile(diffusion_model)

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)
|
||||
|
||||
|
||||
class OpenAIWrapper(IdentityWrapper):
    """Adapter mapping the conditioning dict onto OpenAI-UNet keyword args."""

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        # Channel-concat conditioning; an empty tensor when none is given.
        concat_cond = c.get("concat", torch.Tensor([]).type_as(x))
        x = torch.cat((x, concat_cond), dim=1)
        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **kwargs
        )
|
||||
Reference in New Issue
Block a user