fix decoder to use separate conditional dropout probabilities for image embeddings and text encodings, thanks to @xiankgx!

This commit is contained in:
Phil Wang
2022-04-30 08:47:56 -07:00
parent 721a444686
commit f19c99ecb0
3 changed files with 26 additions and 15 deletions

View File

@@ -110,7 +110,8 @@ decoder = Decoder(
unet = unet, unet = unet,
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here) unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in) image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda() ).cuda()
@@ -558,7 +561,8 @@ decoder = Decoder(
image_sizes = (128, 256), image_sizes = (128, 256),
clip = clip, clip = clip,
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
).cuda() ).cuda()
@@ -669,7 +673,8 @@ decoder = Decoder(
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here) unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100, timesteps = 100,
cond_drop_prob = 0.2 image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5
).cuda() ).cuda()
# mock images (get a lot of this) # mock images (get a lot of this)

View File

@@ -1174,7 +1174,7 @@ class Unet(nn.Module):
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -1185,7 +1185,8 @@ class Unet(nn.Module):
image_embed, image_embed,
lowres_cond_img = None, lowres_cond_img = None,
text_encodings = None, text_encodings = None,
cond_drop_prob = 0., image_cond_drop_prob = 0.,
text_cond_drop_prob = 0.,
blur_sigma = None, blur_sigma = None,
blur_kernel_size = None blur_kernel_size = None
): ):
@@ -1204,8 +1205,10 @@ class Unet(nn.Module):
# conditional dropout # conditional dropout
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1 1') text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
# mask out image embedding depending on condition dropout # mask out image embedding depending on condition dropout
# for classifier free guidance # for classifier free guidance
@@ -1216,7 +1219,7 @@ class Unet(nn.Module):
image_tokens = self.image_to_cond(image_embed) image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where( image_tokens = torch.where(
keep_mask, image_keep_mask,
image_tokens, image_tokens,
self.null_image_embed self.null_image_embed
) )
@@ -1228,7 +1231,7 @@ class Unet(nn.Module):
if exists(text_encodings) and self.cond_on_text_encodings: if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings) text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where( text_tokens = torch.where(
keep_mask, text_keep_mask,
text_tokens, text_tokens,
self.null_text_embed[:, :text_tokens.shape[1]] self.null_text_embed[:, :text_tokens.shape[1]]
) )
@@ -1318,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
clip, clip,
vae = tuple(), vae = tuple(),
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0.2, image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l1', loss_type = 'l1',
beta_schedule = 'cosine', beta_schedule = 'cosine',
predict_x_start = False, predict_x_start = False,
@@ -1402,7 +1406,8 @@ class Decoder(BaseGaussianDiffusion):
# classifier free guidance # classifier free guidance
self.cond_drop_prob = cond_drop_prob self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
def get_unet(self, unet_number): def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets) assert 0 < unet_number <= len(self.unets)
@@ -1484,7 +1489,8 @@ class Decoder(BaseGaussianDiffusion):
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
) )
target = noise if not predict_x_start else x_start target = noise if not predict_x_start else x_start

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.73', version = '0.0.74',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',