Compare commits

...

29 Commits

Author SHA1 Message Date
Phil Wang
9773f10d6c use inference mode whenever possible, cleanup 2022-05-04 15:25:05 -07:00
Phil Wang
a6bf8ddef6 advertise laion 2022-05-04 15:04:05 -07:00
Phil Wang
86e692d24f fix random crop probability 2022-05-04 11:52:24 -07:00
Phil Wang
97b751209f allow for last unet in the cascade to be trained on crops, if it is convolution-only 2022-05-04 11:48:48 -07:00
Phil Wang
74103fd8d6 product management 2022-05-04 11:20:50 -07:00
Phil Wang
1992d25cad project management 2022-05-04 11:18:54 -07:00
Phil Wang
5b619c2fd5 make sure some hyperparameters for unet block is configurable 2022-05-04 11:18:32 -07:00
Phil Wang
9359ad2e91 0.0.95 2022-05-04 10:53:05 -07:00
Phil Wang
9ff228188b offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work 2022-05-04 10:52:58 -07:00
Kumar R
2d9963d30e Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
2022-05-04 08:04:36 -07:00
Phil Wang
58d9b422f3 0.0.94 2022-05-04 07:42:33 -07:00
Ray Bell
44b319cb57 add missing import (#56) 2022-05-04 07:42:20 -07:00
Phil Wang
c30f380689 final reminder 2022-05-03 08:18:53 -07:00
Phil Wang
e4e884bb8b keep all doors open 2022-05-03 08:17:02 -07:00
Phil Wang
803ad9c17d product management again 2022-05-03 08:15:25 -07:00
Phil Wang
a88dd6a9c0 todo 2022-05-03 08:09:02 -07:00
Kumar R
72c16b496e Update train_diffusion_prior.py (#53) 2022-05-02 22:44:57 -07:00
z
81d83dd7f2 defaults align with paper (#52)
Co-authored-by: nousr <>
2022-05-02 13:52:11 -07:00
Phil Wang
fa66f7e1e9 todo 2022-05-02 12:57:15 -07:00
Phil Wang
aa8d135245 allow laion to experiment with normformer in diffusion prior 2022-05-02 11:35:00 -07:00
Phil Wang
70282de23b add ability to turn on normformer settings, given @borisdayma reported good results and some personal anecdata 2022-05-02 11:33:15 -07:00
Phil Wang
83f761847e todo 2022-05-02 10:52:39 -07:00
Phil Wang
11469dc0c6 makes more sense to keep this as True as default, for stability 2022-05-02 10:50:55 -07:00
Romain Beaumont
2d25c89f35 Fix passing of l2norm_output to DiffusionPriorNetwork (#51) 2022-05-02 10:48:16 -07:00
Phil Wang
3fe96c208a add ability to train diffusion prior with l2norm on output image embed 2022-05-02 09:53:20 -07:00
Phil Wang
0fc6c9cdf3 provide option to l2norm the output of the diffusion prior 2022-05-02 09:41:03 -07:00
Phil Wang
7ee0ecc388 mixed precision for training diffusion prior + save optimizer and scaler states 2022-05-02 09:31:04 -07:00
Phil Wang
1924c7cc3d fix issue with mixed precision and gradient clipping 2022-05-02 09:20:19 -07:00
Phil Wang
f7df3caaf3 address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49 2022-05-02 08:51:41 -07:00
6 changed files with 270 additions and 66 deletions
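The headline changes in this range: selectable resnet/convnext blocks in the Unet, optional random-crop training for the final convolution-only unet in the cascade, NormFormer-style normalization and an optional l2norm on the diffusion prior output, `torch.inference_mode()` on the sampling paths, and corrected mixed-precision gradient clipping in the trainers. A minimal, hypothetical sketch of how the new keyword arguments surface; only the keyword names are taken from the diffs below, the values are illustrative:

```python
from dalle2_pytorch import DiffusionPriorNetwork

# new flags on the prior network (see the dalle2_pytorch.py diff below)
prior_network = DiffusionPriorNetwork(
    dim = 768,              # CLIP image embedding dimension (illustrative value)
    depth = 6,
    dim_head = 64,
    heads = 8,
    normformer = True,      # extra LayerNorms after attention output and feedforward activation
    l2norm_output = True    # l2-normalize the predicted image embedding
)

# new keyword arguments on the decoder side (names only, as shown in the diffs below):
#   Unet(..., block_type = 'resnet' | 'convnext', block_resnet_groups = 8, block_convnext_mult = 2)
#   Decoder(..., random_crop_sizes = (None, 64))  # crop-based training for the last, conv-only unet
```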

View File

@@ -10,7 +10,7 @@ The main novelty seems to be an extra layer of indirection with the prior network
 This model is SOTA for text-to-image for now.
-Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
+Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
@@ -821,7 +821,9 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
 - [x] bring in tools to train vqgan-vae
 - [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
-- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
+- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
+- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
@@ -831,6 +833,10 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
+- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
+- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
+- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
+- [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor)
 ## Citations
@@ -896,4 +902,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
-*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
+```bibtex
+@article{Shleifer2021NormFormerIT,
+    title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
+    author = {Sam Shleifer and Jason Weston and Myle Ott},
+    journal = {ArXiv},
+    year = {2021},
+    volume = {abs/2110.09456}
+}
+```
+*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
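
The NormFormer citation added above corresponds to the `post_norm` / `post_activation_norm` flags introduced for `Attention` and `FeedForward` further down. A minimal sketch (not the repo's exact `FeedForward`) of what post-activation normalization means: an extra LayerNorm right after the nonlinearity, before the down-projection:

```python
import torch
from torch import nn

def feedforward(dim, mult = 4, post_activation_norm = False):
    # NormFormer-style post-activation norm: https://arxiv.org/abs/2110.09456
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),  # the extra norm
        nn.Linear(inner_dim, dim)
    )

x = torch.randn(2, 16, 512)
print(feedforward(512, post_activation_norm = True)(x).shape)  # torch.Size([2, 16, 512])
```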

View File

@@ -1,6 +1,7 @@
 import click
 import torch
 import torchvision.transforms as T
+from functools import reduce
 from pathlib import Path
 from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

View File

@@ -16,6 +16,7 @@ from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 from kornia.filters import gaussian_blur2d
+import kornia.augmentation as K
 from dalle2_pytorch.tokenizer import tokenizer
 from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
@@ -29,6 +30,9 @@ from x_clip import CLIP
 def exists(val):
     return val is not None
+def identity(t, *args, **kwargs):
+    return t
 def default(val, d):
     if exists(val):
         return val
@@ -496,7 +500,12 @@ class SwiGLU(nn.Module):
         x, gate = x.chunk(2, dim = -1)
         return x * F.silu(gate)
-def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
+def FeedForward(
+    dim,
+    mult = 4,
+    dropout = 0.,
+    post_activation_norm = False
+):
     """ post-activation norm https://arxiv.org/abs/2110.09456 """
     inner_dim = int(mult * dim)
@@ -519,7 +528,8 @@ class Attention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        causal = False
+        causal = False,
+        post_norm = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -534,7 +544,11 @@ class Attention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim) if post_norm else nn.Identity()
+        )
     def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
@@ -596,10 +610,11 @@ class CausalTransformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False,
+        norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
-        final_proj = True
+        final_proj = True,
+        normformer = False
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
@@ -607,8 +622,8 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
-                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
@@ -635,12 +650,14 @@ class DiffusionPriorNetwork(nn.Module):
         self,
         dim,
         num_timesteps = None,
+        l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
         **kwargs
     ):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+        self.l2norm_output = l2norm_output
     def forward_with_cond_scale(
         self,
@@ -719,7 +736,8 @@ class DiffusionPriorNetwork(nn.Module):
         pred_image_embed = tokens[..., -1, :]
-        return pred_image_embed
+        output_fn = l2norm if self.l2norm_output else identity
+        return output_fn(pred_image_embed)
 class DiffusionPrior(BaseGaussianDiffusion):
     def __init__(
@@ -787,7 +805,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
@@ -796,7 +814,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, shape, text_cond):
         device = self.betas.device
@@ -824,7 +842,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def sample(self, text, num_samples_per_batch = 2):
         # in the paper, what they did was
@@ -913,6 +931,72 @@ class SinusoidalPosEmb(nn.Module):
         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
         return torch.cat((emb.sin(), emb.cos()), dim = -1)
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        groups = 8
+    ):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim_out, 3, padding = 1),
+            nn.GroupNorm(groups, dim_out),
+            nn.SiLU()
+        )
+    def forward(self, x):
+        return self.block(x)
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        *,
+        cond_dim = None,
+        time_cond_dim = None,
+        groups = 8
+    ):
+        super().__init__()
+        self.time_mlp = None
+        if exists(time_cond_dim):
+            self.time_mlp = nn.Sequential(
+                nn.SiLU(),
+                nn.Linear(time_cond_dim, dim_out)
+            )
+        self.cross_attn = None
+        if exists(cond_dim):
+            self.cross_attn = EinopsToAndFrom(
+                'b c h w',
+                'b (h w) c',
+                CrossAttention(
+                    dim = dim_out,
+                    context_dim = cond_dim
+                )
+            )
+        self.block1 = Block(dim, dim_out, groups = groups)
+        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+    def forward(self, x, cond = None, time_emb = None):
+        h = self.block1(x)
+        if exists(self.time_mlp) and exists(time_emb):
+            time_emb = self.time_mlp(time_emb)
+            h = rearrange(time_emb, 'b c -> b c 1 1') + h
+        if exists(self.cross_attn):
+            assert exists(cond)
+            h = self.cross_attn(h, context = cond) + h
+        h = self.block2(h)
+        return h + self.res_conv(x)
 class ConvNextBlock(nn.Module):
     """ https://arxiv.org/abs/2201.03545 """
@@ -923,8 +1007,7 @@ class ConvNextBlock(nn.Module):
         *,
         cond_dim = None,
         time_cond_dim = None,
-        mult = 2,
-        norm = True
+        mult = 2
     ):
         super().__init__()
         need_projection = dim != dim_out
@@ -953,7 +1036,7 @@ class ConvNextBlock(nn.Module):
         inner_dim = int(dim_out * mult)
         self.net = nn.Sequential(
-            ChanLayerNorm(dim) if norm else nn.Identity(),
+            ChanLayerNorm(dim),
             nn.Conv2d(dim, inner_dim, 3, padding = 1),
             nn.GELU(),
             nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -1065,7 +1148,11 @@ class LinearAttention(nn.Module):
         self.nonlin = nn.GELU()
         self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
-        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim, 1, bias = False),
+            ChanLayerNorm(dim)
+        )
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
@@ -1108,7 +1195,11 @@ class Unet(nn.Module):
         max_text_len = 256,
         cond_on_image_embeds = False,
         init_dim = None,
-        init_conv_kernel_size = 7
+        init_conv_kernel_size = 7,
+        block_type = 'resnet',
+        block_resnet_groups = 8,
+        block_convnext_mult = 2,
+        **kwargs
     ):
         super().__init__()
         # save locals to take care of some hyperparameters for cascading DDPM
@@ -1183,6 +1274,15 @@ class Unet(nn.Module):
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+        # whether to use resnet or the (improved?) convnext blocks
+        if block_type == 'resnet':
+            block_klass = partial(ResnetBlock, groups = block_resnet_groups)
+        elif block_type == 'convnext':
+            block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
+        else:
+            raise ValueError(f'unimplemented block type {block_type}')
         # layers
         self.downs = nn.ModuleList([])
@@ -1195,32 +1295,32 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
+                block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
         mid_dim = dims[-1]
-        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
-        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
+        self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
             self.ups.append(nn.ModuleList([
-                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
+                block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
         out_dim = default(out_dim, channels)
         self.final_conv = nn.Sequential(
-            ConvNextBlock(dim, dim),
+            block_klass(dim, dim),
             nn.Conv2d(dim, out_dim, 1)
         )
@@ -1351,10 +1451,10 @@ class Unet(nn.Module):
         hiddens = []
-        for convnext, sparse_attn, convnext2, downsample in self.downs:
-            x = convnext(x, c, t)
+        for block1, sparse_attn, block2, downsample in self.downs:
+            x = block1(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c, t)
+            x = block2(x, c, t)
             hiddens.append(x)
             x = downsample(x)
@@ -1365,11 +1465,11 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
-        for convnext, sparse_attn, convnext2, upsample in self.ups:
+        for block1, sparse_attn, block2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = convnext(x, c, t)
+            x = block1(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c, t)
+            x = block2(x, c, t)
             x = upsample(x)
         return self.final_conv(x)
@@ -1427,6 +1527,7 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start = False, predict_x_start = False,
predict_x_start_for_latent_diffusion = False, predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma blur_sigma = 0.1, # cascading ddpm - blur sigma
@@ -1489,6 +1590,10 @@ class Decoder(BaseGaussianDiffusion):
self.image_sizes = image_sizes self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes)) self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
# predict x0 config # predict x0 config
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
@@ -1534,12 +1639,6 @@ class Decoder(BaseGaussianDiffusion):
             yield
             unet.cpu()
-    @torch.no_grad()
-    def get_image_embed(self, image):
-        image_embed, _ = self.clip.embed_image(image)
-        return image_embed
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
@@ -1554,7 +1653,7 @@ class Decoder(BaseGaussianDiffusion):
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
@@ -1563,7 +1662,7 @@ class Decoder(BaseGaussianDiffusion):
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-    @torch.no_grad()
+    @torch.inference_mode()
     def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device
@@ -1607,7 +1706,7 @@ class Decoder(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def sample(
         self,
@@ -1678,10 +1777,10 @@ class Decoder(BaseGaussianDiffusion):
         unet = self.get_unet(unet_number)
-        target_image_size = self.image_sizes[unet_index]
-        vae = self.vaes[unet_index]
+        vae = self.vaes[unet_index]
+        target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        random_crop_size = self.random_crop_sizes[unet_index]
         b, c, h, w, device, = *image.shape, image.device
         check_shape(image, 'b c h w', c = self.channels)
@@ -1702,6 +1801,14 @@ class Decoder(BaseGaussianDiffusion):
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)
+        if exists(random_crop_size):
+            aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
+            # make sure low res conditioner and image both get augmented the same way
+            # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
+            image = aug(image)
+            lowres_cond_img = aug(lowres_cond_img, params = aug._params)
         vae.eval()
         with torch.no_grad():
             image = vae.encode(image)
@@ -1732,7 +1839,7 @@ class DALLE2(nn.Module):
         self.to_pil = T.ToPILImage()
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def forward(
         self,
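
A self-contained sketch of the paired random crop introduced in the `Decoder` forward above (tensor shapes are illustrative): the low-resolution conditioner must be cropped at exactly the same location as the target image, so the crop parameters sampled by the first call are reused for the second.

```python
import torch
import kornia.augmentation as K

image = torch.randn(4, 3, 256, 256)
lowres_cond_img = torch.randn(4, 3, 256, 256)  # already upsampled to the same resolution

aug = K.RandomCrop((64, 64), p = 1.)
image_crop = aug(image)
lowres_crop = aug(lowres_cond_img, params = aug._params)  # reuse the same crop boxes

assert image_crop.shape == lowres_crop.shape == (4, 3, 64, 64)
```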

View File

@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1
         unet = self.decoder.unets[index]
-        if exists(self.max_grad_norm):
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
         optimizer = getattr(self, f'optim{index}')
         scaler = getattr(self, f'scaler{index}')
+        if exists(self.max_grad_norm):
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
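
A minimal sketch of the corrected mixed-precision update order used above: gradients are unscaled before clipping, then stepped through the `GradScaler`. The model and optimizer here are placeholders, and the scaler/autocast are disabled so the snippet runs on CPU.

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = False)  # enable on CUDA for real mixed precision
max_grad_norm = 0.5

with autocast(enabled = False):
    loss = model(torch.randn(8, 10)).pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                   # undo the loss scaling first
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # clip the true gradients
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```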

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.89',
+  version = '0.0.99',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',

View File

@@ -1,36 +1,78 @@
 import os
 import math
 import argparse
+import numpy as np
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
 from dalle2_pytorch.optimizer import get_optimizer
+from torch.cuda.amp import autocast,GradScaler
 import time
 from tqdm import tqdm
 import wandb
 os.environ["WANDB_SILENT"] = "true"
+NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
+REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+        total_loss = 0.
+        total_samples = 0.
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                 text_reader(batch_size=batch_size, start=start, end=end)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+            batches = emb_images_tensor.shape[0]
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
-            # Log to wandb
-            wandb.log({f'{phase} {loss_type}': loss})
+            total_loss += loss.item() * batches
+            total_samples += batches
+        avg_loss = (total_loss / total_samples)
+        wandb.log({f'{phase} {loss_type}': avg_loss})
-def save_model(save_path,state_dict):
+def save_model(save_path, state_dict):
     # Saving State Dict
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
+def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+    tstart = train_set_size+val_set_size
+    tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
+    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+        text_embed = torch.tensor(embt[0]).to(device)
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed = text_embed)
+        test_image_embeddings = torch.tensor(embi[0]).to(device)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+        return np.mean(predicted_similarity - original_similarity)
 def train(image_embed_dim,
           image_embed_url,
           text_embed_url,
@@ -43,6 +85,8 @@ def train(image_embed_dim,
           clip,
           dp_condition_on_text_encodings,
           dp_timesteps,
+          dp_l2norm_output,
+          dp_normformer,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
@@ -52,14 +96,17 @@ def train(image_embed_dim,
           device,
          learning_rate=0.001,
          max_grad_norm=0.5,
-          weight_decay=0.01):
+          weight_decay=0.01,
+          amp=False):
     # DiffusionPriorNetwork
     prior_network = DiffusionPriorNetwork(
             dim = image_embed_dim,
             depth = dpn_depth,
             dim_head = dpn_dim_head,
-            heads = dpn_heads).to(device)
+            heads = dpn_heads,
+            normformer = dp_normformer,
+            l2norm_output = dp_l2norm_output).to(device)
     # DiffusionPrior with text embeddings and image embeddings pre-computed
     diffusion_prior = DiffusionPrior(
@@ -82,6 +129,7 @@ def train(image_embed_dim,
         os.makedirs(save_path)
     ### Training code ###
+    scaler = GradScaler(enabled=amp)
     optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
     epochs = num_epochs
@@ -98,23 +146,45 @@ def train(image_embed_dim,
                 text_reader(batch_size=batch_size, start=0, end=train_set_size)):
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-            optimizer.zero_grad()
-            loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
-            loss.backward()
+            with autocast(enabled=amp):
+                loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
+            scaler.scale(loss).backward()
             # Samples per second
             step+=1
             samples_per_sec = batch_size*step/(time.time()-t)
             # Save checkpoint every save_interval minutes
             if(int(time.time()-t) >= 60*save_interval):
                 t = time.time()
-                save_model(save_path,diffusion_prior.state_dict())
+                save_model(
+                    save_path,
+                    dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
             # Log to wandb
             wandb.log({"Training loss": loss.item(),
                         "Steps": step,
                         "Samples per second": samples_per_sec})
+            # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
+            # Use NUM_TEST_EMBEDDINGS samples from the test set each time
+            # Get embeddings from the most recently saved model
+            if(step % REPORT_METRICS_EVERY) == 0:
+                diff_cosine_sim = report_cosine_sims(diffusion_prior,
+                                                     image_reader,
+                                                     text_reader,
+                                                     train_set_size,
+                                                     val_set_size,
+                                                     NUM_TEST_EMBEDDINGS,
+                                                     device)
+                wandb.log({"Cosine similarity difference": diff_cosine_sim})
-            nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
-            optimizer.step()
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
         ### Evaluate model(validation run) ###
         start = train_set_size
@@ -139,8 +209,8 @@ def main():
     parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
     parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
     # Hyperparameters
-    parser.add_argument("--learning-rate", type=float, default=0.001)
-    parser.add_argument("--weight-decay", type=float, default=0.01)
+    parser.add_argument("--learning-rate", type=float, default=1.1e-4)
+    parser.add_argument("--weight-decay", type=float, default=6.02e-2)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
     parser.add_argument("--batch-size", type=int, default=10**4)
     parser.add_argument("--num-epochs", type=int, default=5)
@@ -158,15 +228,20 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
-    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
+    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
+    parser.add_argument("--dp-normformer", type=bool, default=False)
+    parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
+    parser.add_argument("--amp", type=bool, default=False)
     # Model checkpointing interval(minutes)
     parser.add_argument("--save-interval", type=int, default=30)
     parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
     args = parser.parse_args()
     print("Setting up wandb logging... Please wait...")
     wandb.init(
         entity=args.wandb_entity,
         project=args.wandb_project,
@@ -176,6 +251,7 @@ def main():
             "dataset": args.wandb_dataset,
             "epochs": args.num_epochs,
         })
     print("wandb logging setup done!")
     # Obtain the utilized device.
@@ -197,6 +273,8 @@ def main():
           args.clip,
           args.dp_condition_on_text_encodings,
          args.dp_timesteps,
+          args.dp_l2norm_output,
+          args.dp_normformer,
           args.dp_cond_drop_prob,
           args.dpn_depth,
           args.dpn_dim_head,
@@ -206,7 +284,8 @@ def main():
           device,
          args.learning_rate,
          args.max_grad_norm,
-          args.weight_decay)
+          args.weight_decay,
+          args.amp)
 if __name__ == "__main__":
     main()
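
A small sketch of the metric that `report_cosine_sims` logs above: with unit-normalized embeddings, cosine similarity is a row-wise dot product, and the gap between cos(text, predicted image) and cos(text, real image) tracks prior quality over training. The embeddings here are random placeholders standing in for the LAION embeddings and the prior's `p_sample_loop` output.

```python
import torch
from torch import nn

cos = nn.CosineSimilarity(dim = 1, eps = 1e-6)

text_embed = torch.randn(100, 768)
image_embed = torch.randn(100, 768)
predicted_image_embed = torch.randn(100, 768)  # would come from diffusion_prior.p_sample_loop

text_embed = text_embed / text_embed.norm(dim = 1, keepdim = True)
image_embed = image_embed / image_embed.norm(dim = 1, keepdim = True)
predicted_image_embed = predicted_image_embed / predicted_image_embed.norm(dim = 1, keepdim = True)

original = cos(text_embed, image_embed)
predicted = cos(text_embed, predicted_image_embed)
print((predicted - original).mean().item())  # logged as "Cosine similarity difference"
```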