mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix all issues with text encodings conditioning in the decoder, using null padding tokens technique from dalle v1
This commit is contained in:
@@ -1101,6 +1101,8 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# attention related params
|
||||
@@ -1185,6 +1187,7 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
text_mask = None,
|
||||
image_cond_drop_prob = 0.,
|
||||
text_cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
@@ -1229,11 +1232,25 @@ class Unet(nn.Module):
|
||||
text_tokens = None
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
text_tokens = text_tokens[:, :self.max_text_len]
|
||||
|
||||
text_tokens_len = text_tokens.shape[1]
|
||||
remainder = self.max_text_len - text_tokens_len
|
||||
|
||||
if remainder > 0:
|
||||
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
||||
|
||||
if exists(text_mask):
|
||||
if remainder > 0:
|
||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||
|
||||
text_keep_mask &= text_mask
|
||||
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
text_keep_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||
self.null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1434,8 +1451,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
return image_embed
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -1449,16 +1466,16 @@ class Decoder(BaseGaussianDiffusion):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -1471,6 +1488,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
@@ -1478,7 +1496,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
return img
|
||||
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||
@@ -1488,6 +1506,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
times,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||
@@ -1503,9 +1522,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = None
|
||||
text_encodings = text_mask = None
|
||||
if exists(text):
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -1534,6 +1553,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
@@ -1571,9 +1591,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if not exists(image_embed):
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = None
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
_, text_encodings, _ = self.clip.embed_text(text)
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
@@ -1588,7 +1608,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
Reference in New Issue
Block a user