fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1

This commit is contained in:
Phil Wang
2022-04-30 09:13:34 -07:00
parent f19c99ecb0
commit 1c1e508369
2 changed files with 33 additions and 13 deletions

View File

@@ -1101,6 +1101,8 @@ class Unet(nn.Module):
# for classifier free guidance # for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params # attention related params
@@ -1185,6 +1187,7 @@ class Unet(nn.Module):
image_embed, image_embed,
lowres_cond_img = None, lowres_cond_img = None,
text_encodings = None, text_encodings = None,
text_mask = None,
image_cond_drop_prob = 0., image_cond_drop_prob = 0.,
text_cond_drop_prob = 0., text_cond_drop_prob = 0.,
blur_sigma = None, blur_sigma = None,
@@ -1229,11 +1232,25 @@ class Unet(nn.Module):
text_tokens = None text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings: if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = text_tokens[:, :self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
text_mask = F.pad(text_mask, (0, remainder), value = False)
text_keep_mask &= text_mask
text_tokens = self.text_to_cond(text_encodings) text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where( text_tokens = torch.where(
text_keep_mask, text_keep_mask,
text_tokens, text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]] self.null_text_embed
) )
# main conditioning tokens (c) # main conditioning tokens (c)
@@ -1434,8 +1451,8 @@ class Decoder(BaseGaussianDiffusion):
image_embed, _ = self.clip.embed_image(image) image_embed, _ = self.clip.embed_image(image)
return image_embed return image_embed
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if predict_x_start: if predict_x_start:
x_recon = pred x_recon = pred
@@ -1449,16 +1466,16 @@ class Decoder(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False): def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start) model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad() @torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1): def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
@@ -1471,6 +1488,7 @@ class Decoder(BaseGaussianDiffusion):
torch.full((b,), i, device = device, dtype = torch.long), torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale, cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
predict_x_start = predict_x_start predict_x_start = predict_x_start
@@ -1478,7 +1496,7 @@ class Decoder(BaseGaussianDiffusion):
return img return img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
@@ -1488,6 +1506,7 @@ class Decoder(BaseGaussianDiffusion):
times, times,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
image_cond_drop_prob = self.image_cond_drop_prob, image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob, text_cond_drop_prob = self.text_cond_drop_prob,
@@ -1503,9 +1522,9 @@ class Decoder(BaseGaussianDiffusion):
def sample(self, image_embed, text = None, cond_scale = 1.): def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
text_encodings = None text_encodings = text_mask = None
if exists(text): if exists(text):
_, text_encodings, _ = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1534,6 +1553,7 @@ class Decoder(BaseGaussianDiffusion):
shape, shape,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
text_mask = text_mask,
cond_scale = cond_scale, cond_scale = cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
lowres_cond_img = lowres_cond_img lowres_cond_img = lowres_cond_img
@@ -1571,9 +1591,9 @@ class Decoder(BaseGaussianDiffusion):
if not exists(image_embed): if not exists(image_embed):
image_embed, _ = self.clip.embed_image(image) image_embed, _ = self.clip.embed_image(image)
text_encodings = None text_encodings = text_mask = None
if exists(text) and not exists(text_encodings): if exists(text) and not exists(text_encodings):
_, text_encodings, _ = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -1588,7 +1608,7 @@ class Decoder(BaseGaussianDiffusion):
if exists(lowres_cond_img): if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img) lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start) return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
# main class # main class

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.74', version = '0.0.75',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',