Compare commits

...

29 Commits

Author SHA1 Message Date
Phil Wang
8b9bbec7d1 project management 2022-05-01 09:32:57 -07:00
Phil Wang
1bb9fc9829 add convnext backbone for vqgan-vae, still need to fix groupnorms in resnet encdec 2022-05-01 09:32:24 -07:00
Phil Wang
5e421bd5bb let researchers do the hyperparameter search 2022-05-01 08:46:21 -07:00
Phil Wang
67fcab1122 add MLP based time conditioning to all convnexts, in addition to cross attention. also add an initial convolution, given convnext first depthwise conv 2022-05-01 08:41:02 -07:00
Phil Wang
5bfbccda22 port over vqgan vae trainer 2022-05-01 08:09:15 -07:00
Phil Wang
989275ff59 product management 2022-04-30 16:57:56 -07:00
Phil Wang
56408f4a40 project management 2022-04-30 16:57:02 -07:00
Phil Wang
d1a697ac23 allows one to shortcut sampling at a specific unet number, if one were to be training in stages 2022-04-30 16:05:13 -07:00
Phil Wang
ebe01749ed DecoderTrainer sample method uses the exponentially moving averaged 2022-04-30 14:55:34 -07:00
Phil Wang
63195cc2cb allow for division of loss prior to scaling, for gradient accumulation purposes 2022-04-30 12:56:47 -07:00
Phil Wang
a2ef69af66 take care of mixed precision, and make gradient accumulation do-able externally 2022-04-30 12:27:24 -07:00
Phil Wang
5fff22834e be able to finely customize learning parameters for each unet, take care of gradient clipping 2022-04-30 11:56:05 -07:00
Phil Wang
a9421f49ec simplify Decoder training for the public 2022-04-30 11:45:18 -07:00
Phil Wang
77fa34eae9 fix all clipping / clamping issues 2022-04-30 10:08:24 -07:00
Phil Wang
1c1e508369 fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1 2022-04-30 09:13:34 -07:00
Phil Wang
f19c99ecb0 fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx ! 2022-04-30 08:48:05 -07:00
Phil Wang
721a444686 Merge pull request #37 from ProGamerGov/patch-1
Fix spelling and grammatical errors
2022-04-30 08:19:07 -07:00
ProGamerGov
63450b466d Fix spelling and grammatical errors 2022-04-30 09:18:13 -06:00
Phil Wang
20e7eb5a9b cleanup 2022-04-30 07:22:57 -07:00
Phil Wang
e2f9615afa use @clip-anytorch , thanks to @rom1504 2022-04-30 06:40:54 -07:00
Phil Wang
0d1c07c803 fix a bug with classifier free guidance, thanks to @xiankgx again! 2022-04-30 06:34:57 -07:00
Phil Wang
a389f81138 todo 2022-04-29 15:40:51 -07:00
Phil Wang
0283556608 fix example in readme, since api changed 2022-04-29 13:40:55 -07:00
Phil Wang
5063d192b6 now completely OpenAI CLIP compatible for training
just take care of the logic for AdamW and transformers

used namedtuples for clip adapter embedding outputs
2022-04-29 13:05:01 -07:00
Phil Wang
f4a54e475e add some training fns 2022-04-29 09:44:55 -07:00
Phil Wang
fb662a62f3 fix another bug thanks to @xiankgx 2022-04-29 07:38:32 -07:00
Phil Wang
587c8c9b44 optimize for clarity 2022-04-28 21:59:13 -07:00
Phil Wang
aa900213e7 force first unet in the cascade to be conditioned on image embeds 2022-04-28 20:53:15 -07:00
Phil Wang
cb26187450 vqgan-vae codebook dims should be 256 or smaller 2022-04-28 08:59:03 -07:00
8 changed files with 1000 additions and 109 deletions

192
README.md
View File

@@ -47,7 +47,7 @@ clip = CLIP(
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on iages
use_visual_ssl = True, # whether to do self supervised learning on images
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
@@ -430,8 +433,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
# feed text and images into diffusion prior network
@@ -495,6 +498,95 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
## OpenAI CLIP
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
clip = OpenAIClipAdapter()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['a butterfly trying to escape a tornado'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
```
Now you'll just have to worry about training the Prior and the Decoder!
## Experimental
### DALL-E2 with Latent Diffusion
@@ -528,7 +620,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
# vqgan-vae must be trained beforehand
vae1 = VQGanVAE(
dim = 32,
@@ -581,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda()
# mock images (get a lot of this)
@@ -615,7 +708,83 @@ images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
## Training wrapper (wip)
Offer training wrappers
### Decoder Training
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
```python
import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
text_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_text_encodings = True
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1000,
condition_on_text_encodings = True
).cuda()
decoder_trainer = DecoderTrainer(
decoder,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)
@@ -648,13 +817,16 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
## Citations

View File

@@ -1,4 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -3,6 +3,7 @@ from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
from collections import namedtuple
import torch
import torch.nn.functional as F
@@ -90,8 +91,21 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
@@ -109,6 +123,10 @@ class BaseClipAdapter(nn.Module):
def image_channels(self):
raise NotImplementedError
@property
def max_text_len(self):
raise NotImplementedError
def embed_text(self, text):
raise NotImplementedError
@@ -128,12 +146,18 @@ class XClipAdapter(BaseClipAdapter):
def image_channels(self):
return self.clip.image_channels
@property
def max_text_len(self):
return self.clip.text_seq_len
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return l2norm(text_embed), text_encodings
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
@@ -141,7 +165,69 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed), image_encodings
return EmbeddedImage(l2norm(image_embed), image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
# classifier free guidance functions
@@ -223,7 +309,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -587,14 +684,14 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')
mask &= cond_prob_mask
mask &= keep_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, cond_prob_mask), dim = 1)
mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
@@ -639,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False
):
super().__init__(
beta_schedule = beta_schedule,
@@ -667,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -680,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@@ -703,29 +807,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
def p_losses(self, image_embed, t, text_cond, noise = None):
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
x_recon = self.net(
pred = self.net(
image_embed_noisy,
t,
times,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
to_predict = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -738,12 +834,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_embed, text_encodings = self.clip.embed_text(text)
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
@@ -780,8 +876,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings = self.clip.embed_text(text)
text_mask = text != 0
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
@@ -827,6 +922,7 @@ class ConvNextBlock(nn.Module):
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
mult = 2,
norm = True
):
@@ -845,6 +941,14 @@ class ConvNextBlock(nn.Module):
)
)
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(time_cond_dim, dim)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
@@ -857,9 +961,13 @@ class ConvNextBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None):
def forward(self, x, cond = None, time = None):
h = self.ds_conv(x)
if exists(time) and exists(self.time_mlp):
t = self.time_mlp(time)
h = rearrange(t, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
@@ -964,6 +1072,8 @@ class Unet(nn.Module):
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -981,28 +1091,45 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
time_cond_dim = dim * 4
self.time_mlp = nn.Sequential(
self.to_time_hiddens = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim * num_time_tokens),
nn.Linear(dim, time_cond_dim),
nn.GELU()
)
self.to_time_tokens = nn.Sequential(
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.to_time_cond = nn.Sequential(
nn.Linear(time_cond_dim, time_cond_dim)
)
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
if cond_on_text_encodings:
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1013,6 +1140,8 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
@@ -1031,26 +1160,26 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
@@ -1066,13 +1195,14 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
channels
channels,
cond_on_image_embeds
):
if lowres_cond == self.lowres_cond and channels == self.channels:
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(
self,
@@ -1085,7 +1215,7 @@ class Unet(nn.Module):
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -1096,7 +1226,9 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.,
text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
@@ -1109,14 +1241,23 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# initial convolution
x = self.init_conv(x)
# time conditioning
time_tokens = self.time_mlp(time)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1127,7 +1268,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_keep_mask,
image_tokens,
self.null_image_embed
)
@@ -1138,10 +1279,25 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
text_tokens = torch.where(
cond_prob_mask,
text_keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
self.null_text_embed
)
# main conditioning tokens (c)
@@ -1161,24 +1317,24 @@ class Unet(nn.Module):
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
x = convnext(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = convnext2(x, c, t)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, mid_c)
x = self.mid_block1(x, mid_c, t)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
x = self.mid_block2(x, mid_c, t)
for convnext, sparse_attn, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = convnext(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = convnext2(x, c, t)
x = upsample(x)
return self.final_conv(x)
@@ -1209,7 +1365,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
@@ -1229,7 +1385,8 @@ class Decoder(BaseGaussianDiffusion):
clip,
vae = tuple(),
timesteps = 1000,
cond_drop_prob = 0.2,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
@@ -1240,6 +1397,8 @@ class Decoder(BaseGaussianDiffusion):
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1279,6 +1438,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)
@@ -1312,7 +1472,13 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
# whether to clip when sampling
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
@@ -1336,37 +1502,34 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
image_embed, _ = self.clip.embed_image(image)
return image_embed
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not predict_x_start:
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
@@ -1379,55 +1542,57 @@ class Decoder(BaseGaussianDiffusion):
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start
predict_x_start = predict_x_start,
clip_denoised = clip_denoised
)
return img
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_recon = unet(
pred = unet(
x_noisy,
t,
times,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
target = noise if not predict_x_start else x_start
if self.loss_type == 'l1':
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
def sample(
self,
image_embed,
text = None,
cond_scale = 1.,
stop_at_unet_number = None
):
batch_size = image_embed.shape[0]
text_encodings = None
text_encodings = text_mask = None
if exists(text):
_, text_encodings = self.clip.embed_text(text)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
@@ -1438,6 +1603,7 @@ class Decoder(BaseGaussianDiffusion):
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
@@ -1449,13 +1615,18 @@ class Decoder(BaseGaussianDiffusion):
shape,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
)
img = vae.decode(img)
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break
return img
def forward(
@@ -1486,9 +1657,9 @@ class Decoder(BaseGaussianDiffusion):
if not exists(image_embed):
image_embed, _ = self.clip.embed_image(image)
text_encodings = None
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings):
_, text_encodings = self.clip.embed_text(text)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1503,7 +1674,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
# main class
@@ -1553,4 +1724,3 @@ class DALLE2(nn.Module):
return images[0]
return images

View File

@@ -0,0 +1,29 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 3e-4,
wd = 1e-2,
betas = (0.9, 0.999),
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)

View File

@@ -1,6 +1,43 @@
import copy
from functools import partial
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# exponential moving average wrapper
@@ -9,16 +46,16 @@ class EMA(nn.Module):
self,
model,
beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
self.ema_update_every = ema_update_every
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
@@ -26,7 +63,7 @@ class EMA(nn.Module):
def update(self):
self.step += 1
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
return
if not self.initted:
@@ -51,3 +88,111 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# trainers
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(decoder, Decoder)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.decoder = decoder
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
return output
def forward(
self,
x,
*,
unet_number,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)

View File

@@ -0,0 +1,266 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
# helpers
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
# classes
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
self.transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# main trainer class
class VQGanVAETrainer(nn.Module):
def __init__(
self,
vae,
*,
num_train_steps,
lr,
batch_size,
folder,
grad_accum_every,
wd = 0.,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
image_size = vae.image_size
self.vae = vae
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
self.register_buffer('steps', torch.Tensor([0]))
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
all_parameters = set(vae.parameters())
discr_parameters = set(vae.discr.parameters())
vae_parameters = all_parameters - discr_parameters
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
# split for validation
if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
# dataloader
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
))
self.valid_dl = cycle(DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.apply_grad_penalty_every = apply_grad_penalty_every
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
self.vae.train()
# logs
logs = {}
# update vae (generator)
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
discr_loss = 0
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_optim.zero_grad()
# log
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
# update exponential moving averaged generator
self.ema_vae.update()
# sample results every so often
if not (steps % self.save_results_every):
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
model.eval()
imgs = next(self.dl)
imgs = imgs.to(device)
recons = model(imgs)
nrows = int(sqrt(self.batch_size))
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
logs['reconstructions'] = grid
save_image(grid, str(self.results_folder / f'{filename}.png'))
print(f'{steps}: saving to {str(self.results_folder)}')
# save model every so often
if not (steps % self.save_model_every):
state_dict = self.vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
torch.save(state_dict, model_path)
ema_state_dict = self.ema_vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
torch.save(ema_state_dict, model_path)
print(f'{steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
device = next(self.vae.parameters()).device
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')

View File

@@ -327,6 +327,108 @@ class ResBlock(nn.Module):
def forward(self, x):
return self.net(x) + x
# convnext enc dec
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
class ConvNext(nn.Module):
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
nn.GELU(),
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
)
def forward(self, x):
return self.net(x) + x
class ConvNextEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels = 3,
layers = 4,
layer_mults = None,
num_blocks = 1,
first_conv_kernel_size = 5,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
attn_dropout = 0.,
):
super().__init__()
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_blocks, tuple):
num_blocks = (*((0,) * (layers - 1)), num_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_blocks):
append(self.encoders, ConvNext(dim_out))
prepend(self.decoders, ConvNext(dim_out))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
# vqgan attention layer
class VQGanAttention(nn.Module):
@@ -545,6 +647,7 @@ class VQGanVAE(nn.Module):
l2_recon_loss = False,
use_hinge_loss = True,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
@@ -567,6 +670,8 @@ class VQGanVAE(nn.Module):
enc_dec_klass = ResnetEncDec
elif vae_type == 'vit':
enc_dec_klass = ViTEncDec
elif vae_type == 'convnext':
enc_dec_klass = ConvNextEncDec
else:
raise ValueError(f'{vae_type} not valid')
@@ -579,6 +684,7 @@ class VQGanVAE(nn.Module):
self.vq = VQ(
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.62',
version = '0.0.86',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
],
install_requires=[
'click',
'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
@@ -31,7 +32,7 @@ setup(
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.4.4',
'x-clip>=0.5.1',
'youtokentome'
],
classifiers=[