Compare commits


22 Commits

Author SHA1 Message Date
Phil Wang
bb86ab2404 update sample, and set default gradient clipping value for decoder training 2022-05-16 17:38:30 -07:00
Phil Wang
ae056dd67c samples 2022-05-16 13:46:35 -07:00
Phil Wang
033d6b0ce8 last update 2022-05-16 13:38:33 -07:00
Phil Wang
c7ea8748db default decoder learning rate to what was in the paper 2022-05-16 13:33:54 -07:00
Phil Wang
13382885d9 final update to dalle2 repository for a while - sampling from prior in chunks automatically with max_batch_size keyword given 2022-05-16 12:57:31 -07:00
Phil Wang
c3d4a7ffe4 update working unconditional decoder example 2022-05-16 12:50:07 -07:00
Phil Wang
164d9be444 use a decorator and take care of sampling in chunks (max_batch_size keyword), in case one is sampling a huge grid of images 2022-05-16 12:34:28 -07:00
Phil Wang
5562ec6be2 status updates 2022-05-16 12:01:54 -07:00
Phil Wang
89ff04cfe2 final tweak to EMA class 2022-05-16 11:54:34 -07:00
Phil Wang
f4016f6302 allow for overriding use of EMA during sampling in decoder trainer with use_non_ema keyword, also fix some issues with automatic normalization of images and low res conditioning image if latent diffusion is in play 2022-05-16 11:18:30 -07:00
Phil Wang
1212f7058d allow text encodings and text mask to be passed in on forward and sampling for Decoder class 2022-05-16 10:40:32 -07:00
Phil Wang
dab106d4e5 back to no_grad for now, also keep track and restore unet devices in one_unet_in_gpu contextmanager 2022-05-16 09:36:14 -07:00
Phil Wang
bb151ca6b1 unet_number on decoder trainer only needs to be passed in if there is greater than 1 unet, so that unconditional training of a single ddpm is seamless (experiment in progress locally) 2022-05-16 09:17:17 -07:00
zion
4a59dea4cf Migrate to text-conditioned prior training (#95)
* migrate to conditioned prior

* unify reader logic with a wrapper (#1)

* separate out reader logic

* support both training methods

* Update train prior to use embedding wrapper (#3)

* Support Both Methods

* bug fixes

* small bug fixes

* embedding only wrapper bug

* use smaller val perc

* final bug fix for embedding-only

Co-authored-by: nousr <>
2022-05-15 20:16:38 -07:00
Phil Wang
ecf9e8027d make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder 2022-05-15 19:09:38 -07:00
Phil Wang
36c5079bd7 LazyLinear is not mature, make users pass in text_embed_dim if text conditioning is turned on 2022-05-15 18:56:52 -07:00
Phil Wang
4a4c7ac9e6 cond drop prob for diffusion prior network should default to 0 2022-05-15 18:47:45 -07:00
Phil Wang
fad7481479 todo 2022-05-15 17:00:25 -07:00
Phil Wang
123658d082 cite Ho et al, since cascading ddpm is now trainable 2022-05-15 16:56:53 -07:00
Phil Wang
11d4e11f10 allow for training unconditional ddpm or cascading ddpms 2022-05-15 16:54:56 -07:00
Phil Wang
99778e12de trainer classes now takes care of auto-casting numpy to torch tensors, and setting correct device based on model parameter devices 2022-05-15 15:25:45 -07:00
Phil Wang
0f0011caf0 todo 2022-05-15 14:28:35 -07:00
12 changed files with 608 additions and 164 deletions

View File

@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Status
- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
*ongoing at 21k steps*
## Install
```bash
@@ -706,7 +716,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
-## Training wrapper (wip)
+## Training wrapper
### Decoder Training
@@ -814,8 +824,8 @@ clip = CLIP(
# mock data
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
# prior networks (with transformer)
@@ -848,9 +858,64 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
```
## Bonus
### Unconditional Training
The repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`
ex.
```python
import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer
# unet for the cascading ddpm
unet1 = Unet(
dim = 128,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unets
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (256, 512), # first unet up to 256px, then second to 512px
timesteps = 1000,
unconditional = True
).cuda()
# decoder trainer
decoder_trainer = DecoderTrainer(decoder)
# images (get a lot of this)
images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder
for i in (1, 2):
loss = decoder_trainer(images, unet_number = i)
decoder_trainer.update(unet_number = i)
# do the above for many many many many images
# then it will learn to generate images
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
```
## Dataloaders
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -1013,6 +1078,8 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] for all model classes with hyperparameters that changes the network architecture, make it requirement that they must expose a config property, and write a simple function that asserts that it restores the object correctly
- [ ] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
## Citations
@@ -1101,4 +1168,13 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{ho2021cascaded,
title = {Cascaded Diffusion Models for High Fidelity Image Generation},
author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
journal = {arXiv preprint arXiv:2106.15282},
year = {2021}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
-from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -61,6 +61,9 @@ def default(val, d):
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def module_device(module):
return next(module.parameters()).device
@contextmanager
def null_context(*args, **kwargs):
yield
@@ -794,7 +797,7 @@ class DiffusionPriorNetwork(nn.Module):
text_embed,
text_encodings = None,
mask = None,
-cond_drop_prob = 0.2
+cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -901,6 +904,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -914,8 +918,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
-def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
-pred = self.net(x, t, **text_cond)
+def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_recon = pred
@@ -933,17 +939,17 @@ class DiffusionPrior(BaseGaussianDiffusion):
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
-@torch.inference_mode()
-def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
+@torch.no_grad()
+def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
b, *_, device = *x.shape, x.device
-model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
+model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-@torch.inference_mode()
-def p_sample_loop(self, shape, text_cond):
+@torch.no_grad()
+def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device
b = shape[0]
@@ -954,7 +960,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
-image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
+image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
return image_embed
@@ -978,21 +984,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
loss = self.loss_fn(pred, target)
return loss
-@torch.inference_mode()
+@torch.no_grad()
@eval_decorator
-def sample_batch_size(self, batch_size, text_cond):
+def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
device = self.betas.device
shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
+img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img
-@torch.inference_mode()
+@torch.no_grad()
@eval_decorator
-def sample(self, text, num_samples_per_batch = 2):
+def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1007,7 +1013,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
-image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
# retrieve original unscaled image embed
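Editorial note on the hunks above: `cond_scale` now threads from `DiffusionPrior.sample` down through `p_sample_loop`, `p_sample` and `p_mean_variance`, and the new `can_classifier_guidance` flag means a scale other than 1 is only valid when the prior was built with conditional dropout. A minimal sketch of the intended call, assuming a text-conditioned `diffusion_prior` constructed as in the repository README with `cond_drop_prob > 0` (shapes and values are illustrative):

```python
import torch

# assumes `diffusion_prior` was built with cond_drop_prob = 0.2 (or any value > 0),
# so can_classifier_guidance is True and a cond_scale other than 1. is permitted
text = torch.randint(0, 49408, (4, 256)).cuda()

image_embeds = diffusion_prior.sample(
    text,
    num_samples_per_batch = 2,
    cond_scale = 2.            # classifier free guidance; would trip the new assertion if cond_drop_prob were 0.
)  # (4, 512), assuming image_embed_dim = 512
```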
@@ -1305,7 +1311,7 @@ class Unet(nn.Module):
self,
dim,
*,
-image_embed_dim,
+image_embed_dim = None,
text_embed_dim = None,
cond_dim = None,
num_image_tokens = 4,
@@ -1377,7 +1383,7 @@ class Unet(nn.Module):
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
-) if image_embed_dim != cond_dim else nn.Identity()
+) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1387,7 +1393,8 @@ class Unet(nn.Module):
self.text_to_cond = None
if cond_on_text_encodings:
-self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
+assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
+self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
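With `nn.LazyLinear` removed, a unet that conditions on text encodings must now receive `text_embed_dim` up front. A rough sketch of a conforming constructor call (values are illustrative; `text_embed_dim` should match the CLIP text dimension actually used):

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,            # now required whenever cond_on_text_encodings = True
    cond_on_text_encodings = True,
    dim_mults = (1, 2, 4, 8)
)
```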
@@ -1701,7 +1708,7 @@ class Decoder(BaseGaussianDiffusion):
self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+assert self.unconditional or (exists(clip) ^ exists(image_size)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
self.clip = None
if exists(clip):
@@ -1792,6 +1799,7 @@ class Decoder(BaseGaussianDiffusion):
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
# whether to clip when sampling
@@ -1811,13 +1819,19 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
self.cuda()
-self.unets.cpu()
+devices = [module_device(unet) for unet in self.unets]
+self.unets.cpu()
unet.cuda()
yield
-unet.cpu()
+for unet, device in zip(self.unets, devices):
+unet.to(device)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance:
@@ -1846,7 +1860,7 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance
-@torch.inference_mode()
+@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
@@ -1855,14 +1869,15 @@ class Decoder(BaseGaussianDiffusion):
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-@torch.inference_mode()
-def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+@torch.no_grad()
+def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
-lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+if not is_latent_diffusion:
+lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
@@ -1882,13 +1897,14 @@ class Decoder(BaseGaussianDiffusion):
unnormalize_img = unnormalize_zero_to_one(img)
return unnormalize_img
-def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
+def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
-x_start = normalize_neg_one_to_one(x_start)
-lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+if not is_latent_diffusion:
+x_start = normalize_neg_one_to_one(x_start)
+lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t
@@ -1948,12 +1964,14 @@ class Decoder(BaseGaussianDiffusion):
return loss + vb_loss
-@torch.inference_mode()
+@torch.no_grad()
@eval_decorator
def sample(
self,
image_embed = None,
text = None,
text_mask = None,
text_encodings = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
@@ -1963,8 +1981,8 @@ class Decoder(BaseGaussianDiffusion):
if not self.unconditional:
batch_size = image_embed.shape[0]
-text_encodings = text_mask = None
-if exists(text):
+if exists(text) and not exists(text_encodings) and not self.unconditional:
+assert exists(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2000,7 +2018,8 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start = predict_x_start,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
-lowres_cond_img = lowres_cond_img
+lowres_cond_img = lowres_cond_img,
+is_latent_diffusion = is_latent_diffusion
)
img = vae.decode(img)
@@ -2016,6 +2035,7 @@ class Decoder(BaseGaussianDiffusion):
text = None,
image_embed = None,
text_encodings = None,
text_mask = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
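The added `text_mask` (and `text_encodings`) keywords on `sample` and `forward` let precomputed CLIP text conditioning bypass the decoder's internal `clip`, per commit 1212f7058d. A hedged sketch of that path, assuming a text-conditioned `decoder` with `text_embed_dim = 512` and a 256px first unet; the tensor shapes below are illustrative, mirroring what `clip.embed_text` returns:

```python
import torch

image_embed    = torch.randn(4, 512).cuda()
text_encodings = torch.randn(4, 256, 512).cuda()                # (batch, seq_len, text_embed_dim)
text_mask      = torch.ones(4, 256, dtype = torch.bool).cuda()
images         = torch.randn(4, 3, 256, 256).cuda()

# training step with precomputed conditioning, no clip needed on the decoder
loss = decoder(images, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, unet_number = 1)
loss.backward()

# sampling with the same precomputed conditioning
samples = decoder.sample(image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask)
```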
@@ -2036,12 +2056,11 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-if not exists(image_embed):
+if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image)
-text_encodings = text_mask = None
-if exists(text) and not exists(text_encodings):
+if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text)
@@ -2059,12 +2078,14 @@ class Decoder(BaseGaussianDiffusion):
image = aug(image)
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
-return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
+return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
# main class
@@ -2087,22 +2108,23 @@ class DALLE2(nn.Module):
self.to_pil = T.ToPILImage()
-@torch.inference_mode()
+@torch.no_grad()
@eval_decorator
def forward(
self,
text,
cond_scale = 1.,
prior_cond_scale = 1.,
return_pil_images = False
):
-device = next(self.parameters()).device
+device = module_device(self)
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
-image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
+image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
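The new `prior_cond_scale` keyword lets the top level `DALLE2` wrapper guide the prior and the decoder with different scales. A minimal sketch, assuming `diffusion_prior` and `decoder` were both built with conditional dropout as in the README examples:

```python
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,   # assumed to exist, built with cond_drop_prob > 0
    decoder = decoder          # assumed to exist, built with image/text conditional dropout > 0
)

images = dalle2(
    ['a butterfly trying to escape a tornado'],
    prior_cond_scale = 1.,     # leave the prior unguided
    cond_scale = 2.            # amplify conditioning in the decoder
)
```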

View File

@@ -1 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits

View File

@@ -0,0 +1,180 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader
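The training script further below consumes this module through `make_splits`. A condensed sketch of the text-conditioned call path (the URLs mirror the script's defaults; the point count and device are placeholders):

```python
from dalle2_pytorch.dataloaders import make_splits

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned = True,
    batch_size = 320,
    num_data_points = 1_000_000,   # placeholder; capped at image_reader.count if larger
    train_split = 0.9,
    eval_split = 1e-7,
    device = "cuda:0",
    img_url = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
    meta_url = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
)

# each batch arrives already on `device`: an image embedding tensor and tokenized captions
for image_embed, tokenized_caption in train_loader:
    ...
```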

View File

@@ -0,0 +1,59 @@
from pathlib import Path
import torch
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
# helpers functions
def cycle(dl):
while True:
for data in dl:
yield data
# dataset and dataloader
class Dataset(data.Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.transform = transforms.Compose([
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
def get_images_dataloader(
folder,
*,
batch_size,
image_size,
shuffle = True,
cycle_dl = True,
pin_memory = True
):
ds = Dataset(folder, image_size)
dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
if cycle_dl:
dl = cycle(dl)
return dl
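A short usage sketch for this simple image-folder loader, feeding the unconditional decoder training shown in the README; the folder path is a placeholder and `decoder_trainer` is assumed to be a `DecoderTrainer` built as in that example:

```python
# assumes get_images_dataloader from this module is importable
dl = get_images_dataloader(
    './path/to/images',
    batch_size = 4,
    image_size = 512,
    cycle_dl = True            # wrap the dataloader in an endless generator
)

images = next(dl)              # (4, 3, 512, 512) float tensor in [0, 1]
loss = decoder_trainer(images, unet_number = 1)
decoder_trainer.update(unet_number = 1)
```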

View File

@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
-lr = 2e-5,
+lr = 1e-4,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,

View File

@@ -1,7 +1,7 @@
import time
import copy
from math import ceil
-from functools import partial
+from functools import partial, wraps
from collections.abc import Iterable
import torch
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
import numpy as np
# helper functions
def exists(val):
@@ -45,6 +47,37 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
# decorators
def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
out = fn(model, *args, **kwargs)
return out
return inner
# gradient accumulation functions
def split_iterable(it, split_size):
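The `cast_torch_tensor` decorator above is what lets the trainer classes accept raw numpy arrays and place them on the model's device automatically (commit 99778e12de). A hedged sketch of the effect, assuming a `diffusion_prior_trainer` built as in the README; the shapes are illustrative:

```python
import numpy as np

# numpy batches straight from an EmbeddingReader, never explicitly moved to GPU
image_embed = np.random.randn(32, 512).astype(np.float32)
text_embed  = np.random.randn(32, 512).astype(np.float32)

# cast_torch_tensor converts the numpy arrays to torch tensors and moves them onto
# the trainer's parameter device before the wrapped forward runs
loss = diffusion_prior_trainer(text_embed = text_embed, image_embed = image_embed)
diffusion_prior_trainer.update()
```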
@@ -154,8 +187,8 @@ class EMA(nn.Module):
self.online_model = model
self.ema_model = copy.deepcopy(model)
-self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
self.update_every = update_every
+self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
@@ -164,14 +197,21 @@ class EMA(nn.Module):
device = self.initted.device
self.ema_model.to(device)
def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict())
def update(self):
self.step += 1
-if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
+if (self.step % self.update_every) != 0:
+return
+if self.step <= self.update_after_step:
+self.copy_params_from_model_to_ema()
return
if not self.initted:
-self.ema_model.state_dict(self.online_model.state_dict())
+self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model)
@@ -195,6 +235,16 @@ class EMA(nn.Module):
# diffusion prior trainer
def prior_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
@@ -253,18 +303,23 @@ class DiffusionPriorTrainer(nn.Module):
self.step += 1
-@torch.inference_mode()
+@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
-@torch.inference_mode()
+@torch.no_grad()
@cast_torch_tensor
@prior_sample_in_chunks
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
-@torch.inference_mode()
+@torch.no_grad()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
@cast_torch_tensor
def forward(
self,
*args,
@@ -287,15 +342,31 @@ class DiffusionPriorTrainer(nn.Module):
# decoder trainer
def decoder_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)
if self.decoder.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
return torch.cat(outputs, dim = 0)
return inner
class DecoderTrainer(nn.Module):
def __init__(
self,
decoder,
use_ema = True,
-lr = 2e-5,
+lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
-max_grad_norm = None,
+max_grad_norm = 0.5,
amp = False,
**kwargs
):
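For unconditional decoders, the `decoder_sample_in_chunks` decorator above splits the requested `batch_size` with the `num_to_groups` helper instead of chunking tensor arguments. A small illustration of that splitting (values chosen for the example):

```python
# num_to_groups breaks a total count into full chunks plus a remainder
assert num_to_groups(36, 16) == [16, 16, 4]

# so an unconditional DecoderTrainer.sample call such as
#   decoder_trainer.sample(batch_size = 36, max_batch_size = 16)
# runs three sampling passes of 16, 16 and 4 images and concatenates the results
```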
@@ -307,11 +378,6 @@ class DecoderTrainer(nn.Module):
self.num_unets = len(self.decoder.unets)
self.use_ema = use_ema
-if use_ema:
-has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
-assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
self.ema_unets = nn.ModuleList([])
self.amp = amp
@@ -354,8 +420,11 @@ class DecoderTrainer(nn.Module):
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
-def update(self, unet_number):
-assert 1 <= unet_number <= self.num_unets
+def update(self, unet_number = None):
+if self.num_unets == 1:
+unet_number = default(unet_number, 1)
+assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
@@ -377,15 +446,18 @@ class DecoderTrainer(nn.Module):
self.step += 1
@torch.no_grad()
@cast_torch_tensor
@decoder_sample_in_chunks
def sample(self, *args, **kwargs):
-if self.use_ema:
-trainable_unets = self.decoder.unets
-self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+if kwargs.pop('use_non_ema', False) or not self.use_ema:
+return self.decoder.sample(*args, **kwargs)
+trainable_unets = self.decoder.unets
+self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
-if self.use_ema:
-self.decoder.unets = trainable_unets # restore original training unets
+self.decoder.unets = trainable_unets # restore original training unets
# cast the ema_model unets back to original device
for ema in self.ema_unets:
@@ -393,13 +465,17 @@ class DecoderTrainer(nn.Module):
return output
@cast_torch_tensor
def forward(
self,
*args,
-unet_number,
+unet_number = None,
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
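Taken together, the `DecoderTrainer` changes in this file mean `unet_number` can be omitted for a single-unet decoder, sampling can bypass the EMA weights with `use_non_ema`, and large grids are sampled in chunks. A hedged sketch, assuming `decoder_trainer` wraps an unconditional single-unet decoder (the single-DDPM case mentioned in commit bb151ca6b1):

```python
# images: a (batch, 3, image_size, image_size) tensor of training images
# training: unet_number defaults to 1 when there is only one unet
loss = decoder_trainer(images)
decoder_trainer.update()

# sample 36 images from the exponential moving averaged unet, 4 at a time
samples = decoder_trainer.sample(batch_size = 36, max_batch_size = 4)

# sample from the online (non-EMA) weights instead
samples = decoder_trainer.sample(batch_size = 36, max_batch_size = 4, use_non_ema = True)
```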

BIN

samples/oxford.png (new binary file, 985 KiB, not shown)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
-version = '0.2.32',
+version = '0.3.2',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'numpy',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',

View File

@@ -5,10 +5,13 @@ import time
import numpy as np
import torch
import clip
from torch import nn
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
-from dalle2_pytorch.train import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
+from dalle2_pytorch.dataloaders import make_splits
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
+from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from embedding_reader import EmbeddingReader
@@ -17,8 +20,7 @@ from tqdm import tqdm
# constants
-NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
-REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
+REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
tracker = WandbTracker()
@@ -36,81 +38,106 @@ class Timer:
def elapsed(self):
return time.time() - self.last_time
# functions
-def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
+def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.
total_samples = 0.
-for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
-text_reader(batch_size=batch_size, start=start, end=end)):
-emb_images_tensor = torch.tensor(emb_images[0]).to(device)
-emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-batches = emb_images_tensor.shape[0]
-loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
-total_loss += loss.item() * batches
+for image_embeddings, text_data in tqdm(dataloader):
+batches = image_embeddings.shape[0]
+input_args = dict(image_embed=image_embeddings)
+if text_conditioned:
+input_args = dict(**input_args, text = text_data)
+else:
+input_args = dict(**input_args, text_embed=text_data)
+loss = model(**input_args)
+total_loss += loss * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
tracker.log({f'{phase} {loss_type}': avg_loss})
-def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
+def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
-tstart = train_set_size
-tend = train_set_size+NUM_TEST_EMBEDDINGS
-for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
-image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
-# make a copy of the text embeddings for shuffling
-text_embed = torch.tensor(embt[0]).to(device)
-text_embed_shuffled = text_embed.clone()
-# roll the text embeddings to simulate "unrelated" captions
-rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
-text_embed_shuffled = text_embed_shuffled[rolled_idx]
-text_embed_shuffled = text_embed_shuffled / \
-text_embed_shuffled.norm(dim=1, keepdim=True)
-test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
+for test_image_embeddings, text_data in tqdm(dataloader):
+# we are text conditioned, we produce an embedding from the tokenized text
+if text_conditioned:
+text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
+text_data)
+text_cond = dict(text_embed=text_embedding,
+text_encodings=text_encodings, mask=text_mask)
+else:
+text_embedding = text_data
+text_cond = dict(text_embed=text_embedding)
+# make a copy of the text embeddings for shuffling
+text_embed_shuffled = text_embedding.clone()
+# roll the text to simulate "unrelated" captions
+rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
+text_embed_shuffled = text_embed_shuffled[rolled_idx]
+text_embed_shuffled = text_embed_shuffled / \
+text_embed_shuffled.norm(dim=1, keepdim=True)
+if text_conditioned:
+text_encodings_shuffled = text_encodings[rolled_idx]
+text_mask_shuffled = text_mask[rolled_idx]
+else:
+text_encodings_shuffled = None
+text_mask_shuffled = None
+text_cond_shuffled = dict(text_embed=text_embed_shuffled,
+text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
# prepare the text embedding
-text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
-test_text_cond = dict(text_embed=text_embed)
+text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
# prepare image embeddings
-test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
-(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
+test_image_embeddings.shape, text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
-(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
+test_image_embeddings.shape, text_cond_shuffled)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
@click.command()
@click.option("--wandb-entity", default="laion")
@click.option("--wandb-project", default="diffusion-prior")
@@ -118,29 +145,32 @@ def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,N
@click.option("--wandb-arch", default="DiffusionPrior")
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
@click.option("--learning-rate", default=1.1e-4)
@click.option("--weight-decay", default=6.02e-2)
@click.option("--dropout", default=5e-2)
@click.option("--max-grad-norm", default=0.5)
-@click.option("--batch-size", default=10**4)
+@click.option("--num-data-points", default=250e6)
+@click.option("--batch-size", default=320)
@click.option("--num-epochs", default=5)
@click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.7)
-@click.option("--val-percent", default=0.2)
-@click.option("--test-percent", default=0.1)
-@click.option("--dpn-depth", default=6)
+@click.option("--train-percent", default=0.9)
+@click.option("--val-percent", default=1e-7)
+@click.option("--test-percent", default=0.0999999)
+@click.option("--dpn-depth", default=12)
@click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=8)
-@click.option("--dp-condition-on-text-encodings", default=False)
-@click.option("--dp-timesteps", default=100)
-@click.option("--dp-normformer", default=False)
+@click.option("--dpn-heads", default=12)
+@click.option("--dp-condition-on-text-encodings", default=True)
+@click.option("--dp-timesteps", default=1000)
+@click.option("--dp-normformer", default=True)
@click.option("--dp-cond-drop-prob", default=0.1)
@click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default=None)
+@click.option("--clip", default="ViT-L/14")
@click.option("--amp", default=False)
-@click.option("--save-interval", default=30)
+@click.option("--save-interval", default=120)
@click.option("--save-path", default="./diffusion_prior_checkpoints")
@click.option("--pretrained-model-path", default=None)
@click.option("--gpu-device", default=0)
def train(
wandb_entity,
wandb_project,
@@ -148,10 +178,12 @@ def train(
wandb_arch,
image_embed_url,
text_embed_url,
meta_url,
learning_rate,
weight_decay,
dropout,
max_grad_norm,
num_data_points,
batch_size,
num_epochs,
image_embed_dim,
@@ -170,7 +202,8 @@ def train(
amp,
save_interval,
save_path,
-pretrained_model_path
+pretrained_model_path,
+gpu_device
):
config = {
"learning_rate": learning_rate,
@@ -197,7 +230,7 @@ def train(
# Check if DPRIOR_PATH exists(saved model path)
-DPRIOR_PATH = args.pretrained_model_path
+DPRIOR_PATH = pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if not RESUME:
@@ -211,7 +244,7 @@ def train(
has_cuda = torch.cuda.is_available()
if has_cuda:
-device = torch.device("cuda:0")
+device = torch.device(f"cuda:{gpu_device}")
torch.cuda.set_device(device)
# Training loop
@@ -227,11 +260,17 @@ def train(
normformer = dp_normformer
)
# Load clip model if text-conditioning
if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip)
else:
clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
net = prior_network,
-clip = clip,
+clip = clip_adapter,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
@@ -265,33 +304,37 @@ def train(
Path(save_path).mkdir(exist_ok = True, parents = True)
-# Get image and text embeddings from the servers
-print_ribbon("Downloading Embeddings")
-print_ribbon("Downloading embeddings - image and text")
-image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
-text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
-num_data_points = text_reader.count
+# Utilize wrapper to abstract away loader logic
+loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
+train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
+if dp_condition_on_text_encodings:
+loader_args = dict(**loader_args, meta_url=meta_url)
+else:
+loader_args = dict(**loader_args, txt_url=text_embed_url)
+train_loader, eval_loader, test_loader = make_splits(**loader_args)
### Training code ###
step = 1
timer = Timer()
epochs = num_epochs
-train_set_size = int(train_percent*num_data_points)
-val_set_size = int(val_percent*num_data_points)
-eval_start = train_set_size
for _ in range(epochs):
-for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
-text_reader(batch_size=batch_size, start=0, end=train_set_size)):
-trainer.train()
-emb_images_tensor = torch.tensor(emb_images[0]).to(device)
-emb_text_tensor = torch.tensor(emb_text[0]).to(device)
-loss = trainer(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
+for image, text in tqdm(train_loader):
+diffusion_prior.train()
+input_args = dict(image_embed=image)
+if dp_condition_on_text_encodings:
+input_args = dict(**input_args, text = text)
+else:
+input_args = dict(**input_args, text_embed=text)
+loss = trainer(**input_args)
# Samples per second
@@ -310,37 +353,23 @@ def train(
image_embed_dim)
# Log to wandb
-tracker.log({"Training loss": loss.item(),
+tracker.log({"Training loss": loss,
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
-report_cosine_sims(diffusion_prior,
-image_reader,
-text_reader,
-train_set_size,
-NUM_TEST_EMBEDDINGS,
-device)
+report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
### Evaluate model(validation run) ###
-eval_model(diffusion_prior,
-device,
-image_reader,
-text_reader,
-eval_start,
-eval_start+NUM_TEST_EMBEDDINGS,
-NUM_TEST_EMBEDDINGS,
-dp_loss_type,
-phase="Validation")
+eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
step += 1
trainer.update()
### Test run ###
-test_set_size = int(test_percent*train_set_size)
-start = train_set_size+val_set_size
-end = num_data_points
-eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
+eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
if __name__ == "__main__":
train()