Compare commits

...

11 Commits

4 changed files with 130 additions and 78 deletions

View File

@@ -647,11 +647,12 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae - [ ] bring in tools to train vqgan-vae

View File

@@ -7,6 +7,7 @@ from contextlib import contextmanager
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
import torchvision.transforms as T
from einops import rearrange, repeat from einops import rearrange, repeat
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
@@ -89,6 +90,59 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False) return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# clip related adapters
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
self.clip = clip
@property
def dim_latent(self):
raise NotImplementedError
@property
def image_size(self):
raise NotImplementedError
@property
def image_channels(self):
raise NotImplementedError
def embed_text(self, text):
raise NotImplementedError
def embed_image(self, image):
raise NotImplementedError
class XClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim_latent
@property
def image_size(self):
return self.clip.image_size
@property
def image_channels(self):
return self.clip.image_channels
@torch.no_grad()
def embed_text(self, text):
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return l2norm(text_embed), text_encodings
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed), image_encodings
# classifier free guidance functions # classifier free guidance functions
def prob_mask_like(shape, prob, device): def prob_mask_like(shape, prob, device):
@@ -169,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape timesteps, = betas.shape
self.num_timesteps = int(timesteps) self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas) self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod) self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -593,7 +658,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
) )
if exists(clip): if exists(clip):
assert isinstance(clip, CLIP) if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip) freeze_model_and_make_eval_(clip)
self.clip = clip self.clip = clip
else: else:
@@ -610,29 +678,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.predict_x_start = predict_x_start self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@torch.no_grad()
def get_image_embed(self, image):
assert exists(self.clip)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
@torch.no_grad()
def get_text_cond(self, text):
assert exists(self.clip)
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
if not self.condition_on_text_encodings:
return dict(text_embed = text_embed)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond) pred = self.net(x, t, **text_cond)
@@ -669,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img return img
def p_losses(self, image_embed, t, text_cond, noise = None): def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed)) noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise) image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
x_recon = self.net( pred = self.net(
image_embed_noisy, image_embed_noisy,
t, times,
cond_drop_prob = self.cond_drop_prob, cond_drop_prob = self.cond_drop_prob,
**text_cond **text_cond
) )
to_predict = noise if not self.predict_x_start else image_embed target = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss return loss
@torch.no_grad() @torch.no_grad()
@@ -704,7 +741,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0] batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text) text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text != 0}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed'] text_embeds = text_cond['text_embed']
@@ -736,18 +778,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image): if exists(image):
image_embed = self.get_image_embed(image) image_embed, _ = self.clip.embed_image(image)
# calculate text conditionings, based on what is passed in # calculate text conditionings, based on what is passed in
if exists(text): if exists(text):
text_cond = self.get_text_cond(text) text_embed, text_encodings = self.clip.embed_text(text)
else: text_mask = text != 0
text_cond = dict(
text_embed = text_embed, text_cond = dict(text_embed = text_embed)
text_encodings = text_encodings,
mask = text_mask if self.condition_on_text_encodings:
) assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm # timestep conditioning from ddpm
@@ -756,8 +799,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate forward loss # calculate forward loss
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs) return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
# decoder # decoder
@@ -1027,13 +1069,14 @@ class Unet(nn.Module):
self, self,
*, *,
lowres_cond, lowres_cond,
channels channels,
cond_on_image_embeds
): ):
if lowres_cond == self.lowres_cond and channels == self.channels: if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels} updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**updated_kwargs) return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
@@ -1170,7 +1213,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2) target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size): if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode) cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
if self.training: if self.training:
# when training, blur the low resolution conditional image # when training, blur the low resolution conditional image
@@ -1208,8 +1251,12 @@ class Decoder(BaseGaussianDiffusion):
loss_type = loss_type loss_type = loss_type
) )
assert isinstance(clip, CLIP) if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
freeze_model_and_make_eval_(clip) freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)
self.clip = clip self.clip = clip
self.clip_image_size = clip.image_size self.clip_image_size = clip.image_size
self.channels = clip.image_channels self.channels = clip.image_channels
@@ -1236,6 +1283,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters( one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first, lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels channels = unet_channels
) )
@@ -1290,10 +1338,6 @@ class Decoder(BaseGaussianDiffusion):
yield yield
unet.cpu() unet.cpu()
@torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
@torch.no_grad() @torch.no_grad()
def get_image_embed(self, image): def get_image_embed(self, image):
@@ -1347,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion):
return img return img
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_recon = unet( pred = unet(
x_noisy, x_noisy,
t, times,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
@@ -1363,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion):
target = noise if not predict_x_start else x_start target = noise if not predict_x_start else x_start
if self.loss_type == 'l1': loss = self.loss_fn(pred, target)
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
return loss return loss
@torch.no_grad() @torch.no_grad()
@@ -1379,9 +1415,12 @@ class Decoder(BaseGaussianDiffusion):
def sample(self, image_embed, text = None, cond_scale = 1.): def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
text_encodings = self.get_text_encodings(text) if exists(text) else None text_encodings = None
if exists(text):
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None img = None
@@ -1442,11 +1481,14 @@ class Decoder(BaseGaussianDiffusion):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long) times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
if not exists(image_embed): if not exists(image_embed):
image_embed = self.get_image_embed(image) image_embed, _ = self.clip.embed_image(image)
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None text_encodings = None
if exists(text) and not exists(text_encodings):
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size) image = resize_image_to(image, target_image_size)
@@ -1479,12 +1521,15 @@ class DALLE2(nn.Module):
self.prior_num_samples = prior_num_samples self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
self.to_pil = T.ToPILImage()
@torch.no_grad() @torch.no_grad()
@eval_decorator @eval_decorator
def forward( def forward(
self, self,
text, text,
cond_scale = 1. cond_scale = 1.,
return_pil_images = False
): ):
device = next(self.parameters()).device device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1) one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
@@ -1498,7 +1543,11 @@ class DALLE2(nn.Module):
text_cond = text if self.decoder_need_text_cond else None text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim = 0)))
if one_text: if one_text:
return images[0] return images[0]
return images return images

View File

@@ -545,6 +545,7 @@ class VQGanVAE(nn.Module):
l2_recon_loss = False, l2_recon_loss = False,
use_hinge_loss = True, use_hinge_loss = True,
vgg = None, vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512, vq_codebook_size = 512,
vq_decay = 0.8, vq_decay = 0.8,
vq_commitment_weight = 1., vq_commitment_weight = 1.,
@@ -579,6 +580,7 @@ class VQGanVAE(nn.Module):
self.vq = VQ( self.vq = VQ(
dim = self.enc_dec.encoded_dim, dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size, codebook_size = vq_codebook_size,
decay = vq_decay, decay = vq_decay,
commitment_weight = vq_commitment_weight, commitment_weight = vq_commitment_weight,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.56', version = '0.0.65',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',