Compare commits

..

1 Commits

6 changed files with 85 additions and 339 deletions

110
README.md
View File

@@ -47,7 +47,7 @@ clip = CLIP(
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on images
use_visual_ssl = True, # whether to do self supervised learning on iages
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
@@ -110,8 +110,7 @@ decoder = Decoder(
unet = unet,
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
@@ -230,8 +229,7 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
@@ -350,8 +348,7 @@ decoder = Decoder(
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
cond_drop_prob = 0.2,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
@@ -433,8 +430,8 @@ images = torch.randn(4, 3, 256, 256).cuda()
# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone
clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
clip_image_embeds = diffusion_prior.get_image_embed(images)
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
# feed text and images into diffusion prior network
@@ -498,95 +495,6 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
## OpenAI CLIP
Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
# openai pretrained clip - defaults to ViT/B-32
clip = OpenAIClipAdapter()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['a butterfly trying to escape a tornado'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image (in this example, of size 256x256)
```
Now you'll just have to worry about training the Prior and the Decoder!
## Experimental
### DALL-E2 with Latent Diffusion
@@ -620,7 +528,7 @@ clip = CLIP(
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained beforehand
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
@@ -673,8 +581,7 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
@@ -744,7 +651,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference

View File

@@ -1,5 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -3,7 +3,6 @@ from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
from collections import namedtuple
import torch
import torch.nn.functional as F
@@ -91,21 +90,8 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_img(img):
return img * 2 - 1
def unnormalize_img(normed_img):
return (normed_img + 1) * 0.5
# clip related adapters
EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
super().__init__()
@@ -123,10 +109,6 @@ class BaseClipAdapter(nn.Module):
def image_channels(self):
raise NotImplementedError
@property
def max_text_len(self):
raise NotImplementedError
def embed_text(self, text):
raise NotImplementedError
@@ -146,18 +128,12 @@ class XClipAdapter(BaseClipAdapter):
def image_channels(self):
return self.clip.image_channels
@property
def max_text_len(self):
return self.clip.text_seq_len
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
return l2norm(text_embed), text_encodings
@torch.no_grad()
def embed_image(self, image):
@@ -165,69 +141,7 @@ class XClipAdapter(BaseClipAdapter):
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32'
):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image)
return EmbeddedImage(image_embed.float(), None)
return l2norm(image_embed), image_encodings
# classifier free guidance functions
@@ -309,18 +223,7 @@ class BaseGaussianDiffusion(nn.Module):
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
if loss_type == 'l1':
loss_fn = F.l1_loss
elif loss_type == 'l2':
loss_fn = F.mse_loss
elif loss_type == 'huber':
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type
self.loss_fn = loss_fn
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
@@ -684,14 +587,14 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
mask &= keep_mask
mask &= cond_prob_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
mask = torch.cat((mask, keep_mask), dim = 1)
mask = torch.cat((mask, cond_prob_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
@@ -736,7 +639,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False
):
super().__init__(
beta_schedule = beta_schedule,
@@ -745,12 +647,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
)
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
assert isinstance(clip, BaseClipAdapter)
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.clip = XClipAdapter(clip)
else:
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
self.clip = None
@@ -765,9 +664,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.predict_x_start = predict_x_start
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
@@ -781,9 +677,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@@ -807,21 +700,29 @@ class DiffusionPrior(BaseGaussianDiffusion):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
def p_losses(self, image_embed, times, text_cond, noise = None):
def p_losses(self, image_embed, t, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
pred = self.net(
x_recon = self.net(
image_embed_noisy,
times,
t,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
target = noise if not self.predict_x_start else image_embed
to_predict = noise if not self.predict_x_start else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
loss = self.loss_fn(pred, target)
return loss
@torch.no_grad()
@@ -834,7 +735,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
@@ -876,12 +777,12 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_embed, text_encodings = self.clip.embed_text(text)
text_mask = text != 0
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
# timestep conditioning from ddpm
@@ -891,7 +792,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
# calculate forward loss
return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
# decoder
@@ -1108,8 +1010,6 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
@@ -1163,14 +1063,13 @@ class Unet(nn.Module):
self,
*,
lowres_cond,
channels,
cond_on_image_embeds
channels
):
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
if lowres_cond == self.lowres_cond and channels == self.channels:
return self
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**{**self._locals, **updated_kwargs})
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
@@ -1183,7 +1082,7 @@ class Unet(nn.Module):
if cond_scale == 1:
return logits
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -1194,9 +1093,7 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
text_mask = None,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
@@ -1215,10 +1112,8 @@ class Unet(nn.Module):
# conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1229,7 +1124,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
image_keep_mask,
cond_prob_mask,
image_tokens,
self.null_image_embed
)
@@ -1240,25 +1135,10 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = rearrange(text_mask, 'b n -> b n 1')
text_keep_mask = text_mask & text_keep_mask
text_tokens = torch.where(
text_keep_mask,
cond_prob_mask,
text_tokens,
self.null_text_embed
self.null_text_embed[:, :text_tokens.shape[1]]
)
# main conditioning tokens (c)
@@ -1326,7 +1206,7 @@ class LowresConditioner(nn.Module):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
if self.training:
# when training, blur the low resolution conditional image
@@ -1346,8 +1226,7 @@ class Decoder(BaseGaussianDiffusion):
clip,
vae = tuple(),
timesteps = 1000,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
@@ -1358,8 +1237,6 @@ class Decoder(BaseGaussianDiffusion):
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1371,8 +1248,6 @@ class Decoder(BaseGaussianDiffusion):
clip = XClipAdapter(clip)
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)
self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
@@ -1399,7 +1274,6 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)
@@ -1433,13 +1307,7 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
# whether to clip when sampling
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
self.cond_drop_prob = cond_drop_prob
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
@@ -1463,34 +1331,37 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad()
def get_image_embed(self, image):
image_embed, _ = self.clip.embed_image(image)
return image_embed
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if predict_x_start:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
if clip_denoised and not predict_x_start:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
@@ -1503,34 +1374,38 @@ class Decoder(BaseGaussianDiffusion):
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start,
clip_denoised = clip_denoised
predict_x_start = predict_x_start
)
return img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
pred = unet(
x_recon = unet(
x_noisy,
times,
t,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
cond_drop_prob = self.cond_drop_prob
)
target = noise if not predict_x_start else x_start
loss = self.loss_fn(pred, target)
if self.loss_type == 'l1':
loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(target, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()
@@ -1538,12 +1413,11 @@ class Decoder(BaseGaussianDiffusion):
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
text_encodings = text_mask = None
text_encodings = None
if exists(text):
_, text_encodings, text_mask = self.clip.embed_text(text)
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
img = None
@@ -1558,7 +1432,6 @@ class Decoder(BaseGaussianDiffusion):
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
@@ -1570,10 +1443,8 @@ class Decoder(BaseGaussianDiffusion):
shape,
image_embed = image_embed,
text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale,
predict_x_start = predict_x_start,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img
)
@@ -1609,12 +1480,11 @@ class Decoder(BaseGaussianDiffusion):
if not exists(image_embed):
image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
text_encodings = None
if exists(text) and not exists(text_encodings):
_, text_encodings, text_mask = self.clip.embed_text(text)
_, text_encodings = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -1626,7 +1496,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
# main class
@@ -1670,9 +1540,12 @@ class DALLE2(nn.Module):
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if return_pil_images:
# do some magic - if the user passed in a string text, or a list of strings
# assume they do not know anything about tensors and return PIL Image(s)
images = list(map(self.to_pil, images.unbind(dim = 0)))
if one_text:
return images[0]
return images

View File

@@ -1,29 +0,0 @@
from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2])
wd_params = set(params) - no_wd_params
return wd_params, no_wd_params
def get_optimizer(
params,
lr = 3e-4,
wd = 1e-2,
betas = (0.9, 0.999),
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [
{'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)

View File

@@ -545,7 +545,6 @@ class VQGanVAE(nn.Module):
l2_recon_loss = False,
use_hinge_loss = True,
vgg = None,
vq_codebook_dim = 256,
vq_codebook_size = 512,
vq_decay = 0.8,
vq_commitment_weight = 1.,
@@ -580,7 +579,6 @@ class VQGanVAE(nn.Module):
self.vq = VQ(
dim = self.enc_dec.encoded_dim,
codebook_dim = vq_codebook_dim,
codebook_size = vq_codebook_size,
decay = vq_decay,
commitment_weight = vq_commitment_weight,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.76',
version = '0.0.59',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -23,7 +23,6 @@ setup(
],
install_requires=[
'click',
'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
@@ -32,7 +31,7 @@ setup(
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'x-clip>=0.5.1',
'x-clip>=0.4.4',
'youtokentome'
],
classifiers=[