fix decoder needing separate conditional dropping probabilities for image embeddings and text encodings, thanks to @xiankgx !

This commit is contained in:
Phil Wang
2022-04-30 08:47:56 -07:00
parent 721a444686
commit f19c99ecb0
3 changed files with 26 additions and 15 deletions

View File

@@ -1174,7 +1174,7 @@ class Unet(nn.Module):
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -1185,7 +1185,8 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.,
image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
@@ -1204,8 +1205,10 @@ class Unet(nn.Module):
# conditional dropout
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
@@ -1216,7 +1219,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
keep_mask,
image_keep_mask,
image_tokens,
self.null_image_embed
)
@@ -1228,7 +1231,7 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
keep_mask,
text_keep_mask,
text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]]
)
@@ -1318,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
clip,
vae = tuple(),
timesteps = 1000,
cond_drop_prob = 0.2,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1',
beta_schedule = 'cosine',
predict_x_start = False,
@@ -1402,7 +1406,8 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance
self.cond_drop_prob = cond_drop_prob
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
@@ -1484,7 +1489,8 @@ class Decoder(BaseGaussianDiffusion):
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
target = noise if not predict_x_start else x_start