diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index dca17b9..0bb86b5 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1343,6 +1343,7 @@ class Unet(nn.Module): cond_on_text_encodings = False, max_text_len = 256, cond_on_image_embeds = False, + add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding" init_dim = None, init_conv_kernel_size = 7, resnet_groups = 8, @@ -1396,11 +1397,16 @@ class Unet(nn.Module): nn.Linear(time_cond_dim, time_cond_dim) ) - self.image_to_cond = nn.Sequential( + self.image_to_tokens = nn.Sequential( nn.Linear(image_embed_dim, cond_dim * num_image_tokens), Rearrange('b (n d) -> b n d', n = num_image_tokens) ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity() + self.to_image_hiddens = nn.Sequential( + nn.Linear(image_embed_dim, time_cond_dim), + nn.GELU() + ) if cond_on_image_embeds and add_image_embeds_to_time else None + self.norm_cond = nn.LayerNorm(cond_dim) self.norm_mid_cond = nn.LayerNorm(cond_dim) @@ -1558,6 +1564,13 @@ class Unet(nn.Module): time_tokens = self.to_time_tokens(time_hiddens) t = self.to_time_cond(time_hiddens) + # image embedding to be summed to time embedding + # discovered by @mhh0318 in the paper + + if exists(image_embed) and exists(self.to_image_hiddens): + image_hiddens = self.to_image_hiddens(image_embed) + t = t + image_hiddens + # conditional dropout image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device) @@ -1571,7 +1584,7 @@ class Unet(nn.Module): image_tokens = None if self.cond_on_image_embeds: - image_tokens = self.image_to_cond(image_embed) + image_tokens = self.image_to_tokens(image_embed) null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working image_tokens = torch.where( diff --git a/setup.py b/setup.py index 8165b26..fb2deb3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.5.7', + version = '0.6.0', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',