mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d27029e98 | ||
|
|
3115fa17b3 | ||
|
|
124d8577c8 | ||
|
|
2db0c9794c | ||
|
|
2277b47ffd | ||
|
|
28b58e568c | ||
|
|
924455d97d | ||
|
|
6021945fc8 | ||
|
|
6f76652d11 | ||
|
|
3dda2570ed | ||
|
|
2f3c02dba8 | ||
|
|
908088cfea | ||
|
|
8dc8a3de0d | ||
|
|
35f89556ba | ||
|
|
2b55f753b9 |
16
README.md
16
README.md
@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
|
||||
|
||||
# openai pretrained clip - defaults to ViT/B-32
|
||||
# openai pretrained clip - defaults to ViT-B/32
|
||||
|
||||
clip = OpenAIClipAdapter()
|
||||
|
||||
@@ -1002,12 +1002,13 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
@@ -1093,4 +1094,15 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# constants
|
||||
|
||||
NAT = 1. / math.log(2.)
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
@@ -41,6 +45,14 @@ def exists(val):
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def maybe(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
if not exists(x):
|
||||
return x
|
||||
return fn(x)
|
||||
return inner
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -91,6 +103,9 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# tensor helpers
|
||||
|
||||
def log(t, eps = 1e-12):
|
||||
return torch.log(t.clamp(min = eps))
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
@@ -107,10 +122,10 @@ def resize_image_to(image, target_image_size):
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
# but CLIP may otherwise
|
||||
|
||||
def normalize_img(img):
|
||||
def normalize_neg_one_to_one(img):
|
||||
return img * 2 - 1
|
||||
|
||||
def unnormalize_img(normed_img):
|
||||
def unnormalize_zero_to_one(normed_img):
|
||||
return (normed_img + 1) * 0.5
|
||||
|
||||
# clip related adapters
|
||||
@@ -271,7 +286,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def embed_image(self, image):
|
||||
assert not self.cleared
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image = self.clip_normalize(image)
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
@@ -297,13 +312,43 @@ def noise_like(shape, device, repeat=False):
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
def meanflat(x):
|
||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
|
||||
centered_x = x - means
|
||||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1. / 255.)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1. / 255.)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = log(cdf_plus)
|
||||
log_one_minus_cdf_min = log(1. - cdf_min)
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(x < -thres,
|
||||
log_cdf_plus,
|
||||
torch.where(x > thres,
|
||||
log_one_minus_cdf_min,
|
||||
log(cdf_delta)))
|
||||
|
||||
return log_probs
|
||||
|
||||
def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
@@ -314,21 +359,21 @@ def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps)
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
||||
|
||||
|
||||
def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
betas = torch.linspace(-6, 6, timesteps)
|
||||
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
|
||||
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
|
||||
|
||||
@@ -368,17 +413,21 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
# register buffer helper function to cast double back to float
|
||||
|
||||
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer('betas', betas)
|
||||
register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
@@ -386,19 +435,13 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
@@ -827,7 +870,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.,
|
||||
loss_type = "l1",
|
||||
loss_type = "l2",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
@@ -1228,6 +1271,33 @@ class LinearAttention(nn.Module):
|
||||
out = self.nonlin(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
kernel_sizes,
|
||||
dim_out = None,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
|
||||
dim_out = default(dim_out, dim_in)
|
||||
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
|
||||
# calculate the dimension at each scale
|
||||
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
|
||||
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
|
||||
|
||||
self.convs = nn.ModuleList([])
|
||||
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
|
||||
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
|
||||
|
||||
def forward(self, x):
|
||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||
return torch.cat(fmaps, dim = 1)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1241,6 +1311,7 @@ class Unet(nn.Module):
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
channels_out = None,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 16,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
@@ -1252,6 +1323,9 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1268,12 +1342,12 @@ class Unet(nn.Module):
|
||||
# determine dimensions
|
||||
|
||||
self.channels = channels
|
||||
self.channels_out = default(channels_out, channels)
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
@@ -1333,6 +1407,12 @@ class Unet(nn.Module):
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
# downsample klass
|
||||
|
||||
downsample_klass = Downsample
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1348,7 +1428,7 @@ class Unet(nn.Module):
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1368,11 +1448,9 @@ class Unet(nn.Module):
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
@@ -1382,13 +1460,25 @@ class Unet(nn.Module):
|
||||
*,
|
||||
lowres_cond,
|
||||
channels,
|
||||
channels_out,
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
||||
if lowres_cond == self.lowres_cond and \
|
||||
channels == self.channels and \
|
||||
cond_on_image_embeds == self.cond_on_image_embeds and \
|
||||
cond_on_text_encodings == self.cond_on_text_encodings and \
|
||||
channels_out == self.channels_out:
|
||||
return self
|
||||
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
||||
updated_kwargs = dict(
|
||||
lowres_cond = lowres_cond,
|
||||
channels = channels,
|
||||
channels_out = channels_out,
|
||||
cond_on_image_embeds = cond_on_image_embeds,
|
||||
cond_on_text_encodings = cond_on_text_encodings
|
||||
)
|
||||
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1453,11 +1543,12 @@ class Unet(nn.Module):
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
image_keep_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
null_image_embed
|
||||
)
|
||||
|
||||
# take care of text encodings (optional)
|
||||
@@ -1481,10 +1572,12 @@ class Unet(nn.Module):
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
|
||||
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
text_tokens = torch.where(
|
||||
text_keep_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1572,7 +1665,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
timesteps = 1000,
|
||||
image_cond_drop_prob = 0.1,
|
||||
text_cond_drop_prob = 0.5,
|
||||
loss_type = 'l1',
|
||||
loss_type = 'l2',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
@@ -1585,6 +1678,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
learned_variance = True,
|
||||
vb_loss_weight = 0.001,
|
||||
unconditional = False
|
||||
):
|
||||
super().__init__(
|
||||
@@ -1623,10 +1718,18 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
|
||||
|
||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||
self.learned_variance = learned_variance
|
||||
self.vb_loss_weight = vb_loss_weight
|
||||
|
||||
# construct unets and vaes
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
@@ -1634,12 +1737,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
channels = unet_channels
|
||||
channels = unet_channels,
|
||||
channels_out = unet_channels_out
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
@@ -1702,8 +1807,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
|
||||
if learned_variance:
|
||||
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -1714,24 +1822,38 @@ class Decoder(BaseGaussianDiffusion):
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
|
||||
if learned_variance:
|
||||
# if learned variance, posterio variance and posterior log variance are predicted by the network
|
||||
# by an interpolation of the max and min log beta values
|
||||
# eq 15 - https://arxiv.org/abs/2102.09672
|
||||
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||
|
||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||
posterior_variance = posterior_log_variance.exp()
|
||||
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
@@ -1743,17 +1865,26 @@ class Decoder(BaseGaussianDiffusion):
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
return img
|
||||
unnormalize_img = unnormalize_zero_to_one(img)
|
||||
return unnormalize_img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
# normalize to [-1, 1]
|
||||
|
||||
x_start = normalize_neg_one_to_one(x_start)
|
||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
||||
|
||||
# get x_t
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
|
||||
pred = unet(
|
||||
model_output = unet(
|
||||
x_noisy,
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
@@ -1764,10 +1895,48 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
)
|
||||
|
||||
if learned_variance:
|
||||
pred, _ = model_output.chunk(2, dim = 1)
|
||||
else:
|
||||
pred = model_output
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
if not learned_variance:
|
||||
# return simple loss if not using learned variance
|
||||
return loss
|
||||
|
||||
# most of the code below is transcribed from
|
||||
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
|
||||
# the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
|
||||
# it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
|
||||
|
||||
# if learning the variance, also include the extra weight kl loss
|
||||
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||
|
||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||
|
||||
detached_model_mean = model_mean.detach()
|
||||
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
|
||||
kl = meanflat(kl) * NAT
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
|
||||
decoder_nll = meanflat(decoder_nll) * NAT
|
||||
|
||||
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
||||
|
||||
vb_losses = torch.where(times == 0, decoder_nll, kl)
|
||||
|
||||
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
|
||||
|
||||
vb_loss = vb_losses.mean() * self.vb_loss_weight
|
||||
|
||||
return loss + vb_loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
@@ -1794,7 +1963,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
@@ -1820,6 +1989,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
learned_variance = learned_variance,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
@@ -1849,6 +2019,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
learned_variance = self.learned_variance[unet_index]
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
@@ -1886,7 +2057,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1936,4 +2107,3 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
@@ -105,6 +105,10 @@ class EMA(nn.Module):
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def restore_ema_model_device(self):
|
||||
device = self.initted.device
|
||||
self.ema_model.to(device)
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
|
||||
# cast the ema_model unets back to original device
|
||||
for ema in self.ema_unets:
|
||||
ema.restore_ema_model_device()
|
||||
|
||||
return output
|
||||
|
||||
def forward(
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,11 +10,12 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.2.6',
|
||||
version = '0.2.17',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
long_description_content_type = 'text/markdown',
|
||||
url = 'https://github.com/lucidrains/dalle2-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
|
||||
Reference in New Issue
Block a user