diff --git a/README.md b/README.md
index b0f16a6..b0bc289 100644
--- a/README.md
+++ b/README.md
@@ -110,7 +110,8 @@ decoder = Decoder(
     unet = unet,
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
@@ -229,7 +230,8 @@ decoder = Decoder(
     unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
     timesteps = 1000,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
@@ -558,7 +561,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
@@ -669,7 +673,8 @@ decoder = Decoder(
     unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index e5d4cec..48e8ef8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1174,7 +1174,7 @@ class Unet(nn.Module):
         if cond_scale == 1:
             return logits
 
-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
@@ -1185,7 +1185,8 @@
         image_embed,
         lowres_cond_img = None,
         text_encodings = None,
-        cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.,
+        text_cond_drop_prob = 0.,
         blur_sigma = None,
         blur_kernel_size = None
     ):
@@ -1204,8 +1205,10 @@
 
         # conditional dropout
 
-        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1 1')
+        image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+        text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+
+        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -1216,7 +1219,7 @@
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            keep_mask,
+            image_keep_mask,
            image_tokens,
            self.null_image_embed
         )
@@ -1228,7 +1231,7 @@
         if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
-                keep_mask,
+                text_keep_mask,
                 text_tokens,
                 self.null_text_embed[:, :text_tokens.shape[1]]
             )
@@ -1318,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
         clip,
         vae = tuple(),
         timesteps = 1000,
-        cond_drop_prob = 0.2,
+        image_cond_drop_prob = 0.1,
+        text_cond_drop_prob = 0.5,
         loss_type = 'l1',
         beta_schedule = 'cosine',
         predict_x_start = False,
@@ -1402,7 +1406,8 @@
 
         # classifier free guidance
 
-        self.cond_drop_prob = cond_drop_prob
+        self.image_cond_drop_prob = image_cond_drop_prob
+        self.text_cond_drop_prob = text_cond_drop_prob
 
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
@@ -1484,7 +1489,8 @@
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
-            cond_drop_prob = self.cond_drop_prob
+            image_cond_drop_prob = self.image_cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
         target = noise if not predict_x_start else x_start
diff --git a/setup.py b/setup.py
index 8687aa3..51c7ac3 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.73',
+  version = '0.0.74',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
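
Below is a minimal, self-contained sketch of the technique this patch introduces: the single `cond_drop_prob` keep-mask is split into independent image and text keep-masks, so the two conditioning signals can be dropped at different rates for classifier-free guidance. This is illustrative only, not the library's exact code: `prob_mask_like` is reimplemented inline, the tensor shapes are toy values, and plain `einops.rearrange` is used in place of `rearrange_many` from `einops_exts`.

```python
import torch
from einops import rearrange

def prob_mask_like(shape, prob, device):
    # boolean mask that is True with probability `prob`, per batch element
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch_size, text_len, dim = 4, 8, 16
device = 'cpu'

image_tokens = torch.randn(batch_size, 1, dim, device = device)        # projected image embedding
text_tokens  = torch.randn(batch_size, text_len, dim, device = device) # projected text encodings

# learned null embeddings stand in for dropped conditioning (zeros here for simplicity)
null_image_embed = torch.zeros(1, 1, dim, device = device)
null_text_embed  = torch.zeros(1, text_len, dim, device = device)

# independent keep probabilities, matching the new defaults:
# image conditioning kept 90% of the time, text kept 50% of the time
image_keep_mask = prob_mask_like((batch_size,), 1 - 0.1, device = device)
text_keep_mask  = prob_mask_like((batch_size,), 1 - 0.5, device = device)

image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
text_keep_mask  = rearrange(text_keep_mask, 'b -> b 1 1')

# where a keep mask is False, substitute the null embedding
image_tokens = torch.where(image_keep_mask, image_tokens, null_image_embed)
text_tokens  = torch.where(text_keep_mask, text_tokens, null_text_embed)

# at sampling time, guidance contrasts the conditioned prediction against a
# fully-dropped ("null") forward pass, as in forward_with_cond_scale above:
# guided = null_logits + (logits - null_logits) * cond_scale
```

Note that the new defaults drop text (0.5) far more often than the image embedding (0.1), presumably because the decoder's primary conditioning signal is the image embedding produced by the prior, with text encodings acting as an auxiliary signal.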