From 25d980ebbf1e22ce8396cdec400e22e83f754176 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 17:27:39 -0700 Subject: [PATCH] complete naive conditioning of unet with image embedding, with ability to dropout for classifier free guidance --- dalle2_pytorch/dalle2_pytorch.py | 65 +++++++++++++++++++------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 3c2aaf0..20d4007 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -231,7 +231,7 @@ class DiffusionPriorNetwork(nn.Module): # classifier free guidance - cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device) + cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device) mask &= rearrange(cond_prob_mask, 'b -> b 1') # attend @@ -290,7 +290,7 @@ class ConvNextBlock(nn.Module): dim, dim_out, *, - time_emb_dim = None, + cond_dim = None, mult = 2, norm = True ): @@ -299,8 +299,8 @@ class ConvNextBlock(nn.Module): self.mlp = nn.Sequential( nn.GELU(), - nn.Linear(time_emb_dim, dim) - ) if exists(time_emb_dim) else None + nn.Linear(cond_dim, dim) + ) if exists(cond_dim) else None self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim) @@ -314,12 +314,12 @@ class ConvNextBlock(nn.Module): self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity() - def forward(self, x, time_emb = None): + def forward(self, x, cond = None): h = self.ds_conv(x) if exists(self.mlp): - assert exists(time_emb) - condition = self.mlp(time_emb) + assert exists(cond) + condition = self.mlp(cond) h = h + rearrange(condition, 'b c -> b c 1 1') h = self.net(h) @@ -331,10 +331,10 @@ class Unet(nn.Module): dim, *, image_embed_dim, + time_dim = None, out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, - with_time_emb = True ): super().__init__() self.channels = channels @@ -342,17 +342,18 @@ class Unet(nn.Module): dims = [channels, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - if with_time_emb: - time_dim = dim - self.time_mlp = nn.Sequential( - SinusoidalPosEmb(dim), - nn.Linear(dim, dim * 4), - nn.GELU(), - nn.Linear(dim * 4, dim) - ) - else: - time_dim = None - self.time_mlp = None + time_dim = default(time_dim, dim) + + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, dim) + ) + + self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim)) + + cond_dim = time_dim + image_embed_dim self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) @@ -362,20 +363,20 @@ class Unet(nn.Module): is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ - ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0), - ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim), + ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0), + ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] - self.mid_block = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) self.ups.append(nn.ModuleList([ - ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim), - ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim), + ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim), + ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim), Upsample(dim_in) if not is_last else nn.Identity() ])) @@ -408,10 +409,20 @@ class Unet(nn.Module): text_encodings = None, cond_prob_drop = 0. ): - batch_size, device = image_embed.shape[0], image_embed.device - t = self.time_mlp(time) if exists(self.time_mlp) else None + t = self.time_mlp(time) - cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device) + cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device) + + # mask out image embedding depending on condition dropout + # for classifier free guidance + + image_embed = torch.where( + rearrange(cond_prob_mask, 'b -> b 1'), + image_embed, + rearrange(self.null_image_embed, 'd -> 1 d') + ) + + cond = torch.cat((t, image_embed), dim = -1) hiddens = []