Compare commits

..

15 Commits

Author SHA1 Message Date
Phil Wang
2277b47ffd make sure learned variance can work for any number of unets in the decoder, defaults to first unet, as suggested was used in the paper 2022-05-12 14:18:15 -07:00
Phil Wang
28b58e568c cleanup in preparation of option for learned variance 2022-05-12 12:04:52 -07:00
Phil Wang
924455d97d align the ema model device back after sampling from the cascading ddpm in the decoder 2022-05-11 19:56:54 -07:00
Phil Wang
6021945fc8 default to l2 loss 2022-05-11 19:24:51 -07:00
Light-V
6f76652d11 fix typo in README.md (#85)
The default config for clip from openai should be ViT-B/32
2022-05-11 13:38:16 -07:00
Phil Wang
3dda2570ed fix amp issue for https://github.com/lucidrains/DALLE2-pytorch/issues/82 2022-05-11 08:21:39 -07:00
Phil Wang
2f3c02dba8 numerical accuracy for noise schedule parameters 2022-05-10 15:28:46 -07:00
Phil Wang
908088cfea wrap up cross embed layer feature 2022-05-10 12:19:34 -07:00
Phil Wang
8dc8a3de0d product management 2022-05-10 11:51:38 -07:00
Phil Wang
35f89556ba bring in the cross embed layer from Crossformer paper for initial convolution in unet 2022-05-10 11:50:38 -07:00
Phil Wang
2b55f753b9 fix new issue with github actions and auto pypi package uploading 2022-05-10 10:51:15 -07:00
Phil Wang
fc8fce38fb make sure cascading DDPM can be trained unconditionally, to ready for CLI one command training for the public 2022-05-10 10:48:10 -07:00
Phil Wang
a1bfb03ba4 project management 2022-05-10 10:13:51 -07:00
Phil Wang
b1e7b5f6bb make sure resnet groups in unet is finely customizable 2022-05-10 10:12:50 -07:00
z
10b905b445 smol typo (#81) 2022-05-10 09:52:50 -07:00
5 changed files with 255 additions and 72 deletions

View File

@@ -508,7 +508,7 @@ To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it i
import torch import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32 # openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter() clip = OpenAIClipAdapter()
@@ -999,22 +999,24 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation) - [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias) - [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention - [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [x] cross embed layers for downsampling, as an option
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a> - [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
## Citations ## Citations
@@ -1092,4 +1094,15 @@ Once built, images will be saved to the same directory the command is invoked
} }
``` ```
```bibtex
@misc{wang2021crossformer,
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
year = {2021},
eprint = {2108.00154},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
from x_clip import CLIP from x_clip import CLIP
from coca_pytorch import CoCa from coca_pytorch import CoCa
# constants
NAT = 1. / math.log(2.)
# helper functions # helper functions
def exists(val): def exists(val):
@@ -91,6 +95,9 @@ def freeze_model_and_make_eval_(model):
# tensor helpers # tensor helpers
def log(t, eps = 1e-12):
return torch.log(t.clamp(min = eps))
def l2norm(t): def l2norm(t):
return F.normalize(t, dim = -1) return F.normalize(t, dim = -1)
@@ -297,13 +304,43 @@ def noise_like(shape, device, repeat=False):
noise = lambda: torch.randn(shape, device=device) noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise() return repeat_noise() if repeat else noise()
def meanflat(x):
return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
def approx_standard_normal_cdf(x):
return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1. / 255.)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1. - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(x < -thres,
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
return log_probs
def cosine_beta_schedule(timesteps, s = 0.008): def cosine_beta_schedule(timesteps, s = 0.008):
""" """
cosine schedule cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
""" """
steps = timesteps + 1 steps = timesteps + 1
x = torch.linspace(0, timesteps, steps) x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0] alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
@@ -314,21 +351,21 @@ def linear_beta_schedule(timesteps):
scale = 1000 / timesteps scale = 1000 / timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps) return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
def quadratic_beta_schedule(timesteps): def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps scale = 1000 / timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2 return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
def sigmoid_beta_schedule(timesteps): def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps scale = 1000 / timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps) betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
@@ -368,17 +405,21 @@ class BaseGaussianDiffusion(nn.Module):
self.loss_type = loss_type self.loss_type = loss_type
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.register_buffer('betas', betas) # register buffer helper function to cast double back to float
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0) # calculations for posterior q(x_{t-1} | x_t, x_0)
@@ -386,19 +427,13 @@ class BaseGaussianDiffusion(nn.Module):
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance) register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_posterior(self, x_start, x_t, t): def q_posterior(self, x_start, x_t, t):
posterior_mean = ( posterior_mean = (
@@ -827,7 +862,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_channels = 3, image_channels = 3,
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0., cond_drop_prob = 0.,
loss_type = "l1", loss_type = "l2",
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
@@ -1228,6 +1263,33 @@ class LinearAttention(nn.Module):
out = self.nonlin(out) out = self.nonlin(out)
return self.to_out(out) return self.to_out(out)
class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in,
kernel_sizes,
dim_out = None,
stride = 2
):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
class Unet(nn.Module): class Unet(nn.Module):
def __init__( def __init__(
self, self,
@@ -1241,6 +1303,7 @@ class Unet(nn.Module):
out_dim = None, out_dim = None,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
channels = 3, channels = 3,
channels_out = None,
attn_dim_head = 32, attn_dim_head = 32,
attn_heads = 16, attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1251,8 +1314,10 @@ class Unet(nn.Module):
cond_on_image_embeds = False, cond_on_image_embeds = False,
init_dim = None, init_dim = None,
init_conv_kernel_size = 7, init_conv_kernel_size = 7,
block_type = 'resnet', resnet_groups = 8,
block_resnet_groups = 8, init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -1269,12 +1334,12 @@ class Unet(nn.Module):
# determine dimensions # determine dimensions
self.channels = channels self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2) init_dim = default(init_dim, dim // 3 * 2)
assert (init_conv_kernel_size % 2) == 1 self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)] dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
@@ -1330,7 +1395,15 @@ class Unet(nn.Module):
# resnet block klass # resnet block klass
block_klass = partial(ResnetBlock, groups = block_resnet_groups) resnet_groups = cast_tuple(resnet_groups, len(in_out))
assert len(resnet_groups) == len(in_out)
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
# layers # layers
@@ -1338,39 +1411,38 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([]) self.ups = nn.ModuleList([])
num_resolutions = len(in_out) num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out): for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
is_first = ind == 0 is_first = ind == 0
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim), ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(), Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Downsample(dim_out) if not is_last else nn.Identity() downsample_klass(dim_out) if not is_last else nn.Identity()
])) ]))
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim) self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
is_last = ind >= (num_resolutions - 2) is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(), Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim), ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Upsample(dim_in) Upsample(dim_in)
])) ]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
block_klass(dim, dim), ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, out_dim, 1) nn.Conv2d(dim, self.channels_out, 1)
) )
# if the current settings for the unet are not correct # if the current settings for the unet are not correct
@@ -1380,12 +1452,25 @@ class Unet(nn.Module):
*, *,
lowres_cond, lowres_cond,
channels, channels,
cond_on_image_embeds channels_out,
cond_on_image_embeds,
cond_on_text_encodings
): ):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds: if lowres_cond == self.lowres_cond and \
channels == self.channels and \
cond_on_image_embeds == self.cond_on_image_embeds and \
cond_on_text_encodings == self.cond_on_text_encodings and \
channels_out == self.channels_out:
return self return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds} updated_kwargs = dict(
lowres_cond = lowres_cond,
channels = channels,
channels_out = channels_out,
cond_on_image_embeds = cond_on_image_embeds,
cond_on_text_encodings = cond_on_text_encodings
)
return self.__class__(**{**self._locals, **updated_kwargs}) return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale( def forward_with_cond_scale(
@@ -1450,11 +1535,12 @@ class Unet(nn.Module):
if self.cond_on_image_embeds: if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed) image_tokens = self.image_to_cond(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where( image_tokens = torch.where(
image_keep_mask, image_keep_mask,
image_tokens, image_tokens,
self.null_image_embed null_image_embed
) )
# take care of text encodings (optional) # take care of text encodings (optional)
@@ -1478,10 +1564,12 @@ class Unet(nn.Module):
text_mask = rearrange(text_mask, 'b n -> b n 1') text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
text_tokens = torch.where( text_tokens = torch.where(
text_keep_mask, text_keep_mask,
text_tokens, text_tokens,
self.null_text_embed null_text_embed
) )
# main conditioning tokens (c) # main conditioning tokens (c)
@@ -1569,7 +1657,7 @@ class Decoder(BaseGaussianDiffusion):
timesteps = 1000, timesteps = 1000,
image_cond_drop_prob = 0.1, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5, text_cond_drop_prob = 0.5,
loss_type = 'l1', loss_type = 'l2',
beta_schedule = 'cosine', beta_schedule = 'cosine',
predict_x_start = False, predict_x_start = False,
predict_x_start_for_latent_diffusion = False, predict_x_start_for_latent_diffusion = False,
@@ -1581,7 +1669,10 @@ class Decoder(BaseGaussianDiffusion):
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True, clip_denoised = True,
clip_x_start = True, clip_x_start = True,
clip_adapter_overrides = dict() clip_adapter_overrides = dict(),
learned_variance = True,
vb_loss_weight = 0.001,
unconditional = False
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -1589,6 +1680,9 @@ class Decoder(BaseGaussianDiffusion):
loss_type = loss_type loss_type = loss_type
) )
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)' assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None self.clip = None
@@ -1616,10 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
unets = cast_tuple(unet) unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels)) vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
self.unets = nn.ModuleList([]) self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([]) self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)): for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
assert isinstance(one_unet, Unet) assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE)) assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
@@ -1627,11 +1729,14 @@ class Decoder(BaseGaussianDiffusion):
latent_dim = one_vae.encoded_dim if exists(one_vae) else None latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels) unet_channels = default(latent_dim, self.channels)
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
one_unet = one_unet.cast_model_parameters( one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first, lowres_cond = not is_first,
cond_on_image_embeds = is_first, cond_on_image_embeds = is_first and not unconditional,
channels = unet_channels cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
channels = unet_channels,
channels_out = unet_channels_out
) )
self.unets.append(one_unet) self.unets.append(one_unet)
@@ -1694,8 +1799,11 @@ class Decoder(BaseGaussianDiffusion):
yield yield
unet.cpu() unet.cpu()
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start: if predict_x_start:
x_recon = pred x_recon = pred
@@ -1706,19 +1814,31 @@ class Decoder(BaseGaussianDiffusion):
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network
# by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode() @torch.inference_mode()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode() @torch.inference_mode()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1): def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
@@ -1735,17 +1855,18 @@ class Decoder(BaseGaussianDiffusion):
cond_scale = cond_scale, cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = clip_denoised clip_denoised = clip_denoised
) )
return img return img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
pred = unet( model_output = unet(
x_noisy, x_noisy,
times, times,
image_embed = image_embed, image_embed = image_embed,
@@ -1756,21 +1877,58 @@ class Decoder(BaseGaussianDiffusion):
text_cond_drop_prob = self.text_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob,
) )
if learned_variance:
pred, _ = model_output.chunk(2, dim = 1)
else:
pred = model_output
target = noise if not predict_x_start else x_start target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target) loss = self.loss_fn(pred, target)
return loss
if not learned_variance:
# return simple loss if not using learned variance
return loss
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
detached_model_mean = model_mean.detach()
kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
vb_losses = torch.where(times == 0, decoder_nll, kl)
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
vb_loss = vb_losses.mean() * self.vb_loss_weight
return loss + vb_loss
@torch.inference_mode() @torch.inference_mode()
@eval_decorator @eval_decorator
def sample( def sample(
self, self,
image_embed, image_embed = None,
text = None, text = None,
batch_size = 1,
cond_scale = 1., cond_scale = 1.,
stop_at_unet_number = None stop_at_unet_number = None
): ):
batch_size = image_embed.shape[0] assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
if not self.unconditional:
batch_size = image_embed.shape[0]
text_encodings = text_mask = None text_encodings = text_mask = None
if exists(text): if exists(text):
@@ -1780,10 +1938,11 @@ class Decoder(BaseGaussianDiffusion):
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)): for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context() context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
with context: with context:
lowres_cond_img = None lowres_cond_img = None
@@ -1807,6 +1966,7 @@ class Decoder(BaseGaussianDiffusion):
text_mask = text_mask, text_mask = text_mask,
cond_scale = cond_scale, cond_scale = cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion, clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img lowres_cond_img = lowres_cond_img
) )
@@ -1836,6 +1996,7 @@ class Decoder(BaseGaussianDiffusion):
target_image_size = self.image_sizes[unet_index] target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index] predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index] random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels) check_shape(image, 'b c h w', c = self.channels)
@@ -1873,7 +2034,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img): if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img) lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
# main class # main class
@@ -1923,4 +2084,3 @@ class DALLE2(nn.Module):
return images[0] return images[0]
return images return images

View File

@@ -105,6 +105,10 @@ class EMA(nn.Module):
self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.])) self.register_buffer('step', torch.tensor([0.]))
def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)
def update(self): def update(self):
self.step += 1 self.step += 1
@@ -305,6 +309,11 @@ class DecoderTrainer(nn.Module):
if self.use_ema: if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()
return output return output
def forward( def forward(

View File

@@ -10,11 +10,12 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.4', version = '0.2.14',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
author_email = 'lucidrains@gmail.com', author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle2-pytorch', url = 'https://github.com/lucidrains/dalle2-pytorch',
keywords = [ keywords = [
'artificial intelligence', 'artificial intelligence',

View File

@@ -342,7 +342,7 @@ def main():
RESUME, RESUME,
DPRIOR_PATH, DPRIOR_PATH,
config, config,
atgs.wandb_entity, args.wandb_entity,
args.wandb_project, args.wandb_project,
args.learning_rate, args.learning_rate,
args.max_grad_norm, args.max_grad_norm,