diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 89ab4f4..9ad5025 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -922,6 +922,7 @@ class ConvNextBlock(nn.Module):
         dim_out,
         *,
         cond_dim = None,
+        time_cond_dim = None,
         mult = 2,
         norm = True
     ):
@@ -940,6 +941,14 @@
             )
         )
 
+        self.time_mlp = None
+
+        if exists(time_cond_dim):
+            self.time_mlp = nn.Sequential(
+                nn.GELU(),
+                nn.Linear(time_cond_dim, dim)
+            )
+
         self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
 
         inner_dim = int(dim_out * mult)
@@ -952,9 +961,13 @@
 
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
 
-    def forward(self, x, cond = None):
+    def forward(self, x, cond = None, time = None):
         h = self.ds_conv(x)
 
+        if exists(time) and exists(self.time_mlp):
+            t = self.time_mlp(time)
+            h = rearrange(t, 'b c -> b c 1 1') + h
+
         if exists(self.cross_attn):
             assert exists(cond)
             h = self.cross_attn(h, context = cond) + h
@@ -1076,22 +1089,33 @@ class Unet(nn.Module):
         self.channels = channels
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+        init_dim = dim // 2
 
-        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
+        self.init_conv = nn.Conv2d(init_channels, init_dim, 7, padding = 3)
+
+        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
+        time_cond_dim = dim * 4
 
-        self.time_mlp = nn.Sequential(
+        self.to_time_hiddens = nn.Sequential(
             SinusoidalPosEmb(dim),
-            nn.Linear(dim, dim * 4),
-            nn.GELU(),
-            nn.Linear(dim * 4, cond_dim * num_time_tokens),
+            nn.Linear(dim, time_cond_dim),
+            nn.GELU()
+        )
+
+        self.to_time_tokens = nn.Sequential(
+            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
             Rearrange('b (r d) -> b r d', r = num_time_tokens)
         )
 
+        self.to_time_cond = nn.Sequential(
+            nn.Linear(time_cond_dim, time_cond_dim)
+        )
+
         self.image_to_cond = nn.Sequential(
             nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
@@ -1133,26 +1157,26 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
                 Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
-        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
-                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
 
@@ -1214,9 +1238,16 @@ class Unet(nn.Module):
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
+        # initial convolution
+
+        x = self.init_conv(x)
+
         # time conditioning
 
-        time_tokens = self.time_mlp(time)
+        time_hiddens = self.to_time_hiddens(time)
+
+        time_tokens = self.to_time_tokens(time_hiddens)
+        t = self.to_time_cond(time_hiddens)
 
         # conditional dropout
 
@@ -1283,24 +1314,24 @@ class Unet(nn.Module):
         hiddens = []
 
         for convnext, sparse_attn, convnext2, downsample in self.downs:
-            x = convnext(x, c)
+            x = convnext(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c)
+            x = convnext2(x, c, t)
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, mid_c)
+        x = self.mid_block1(x, mid_c, t)
 
         if exists(self.mid_attn):
             x = self.mid_attn(x)
 
-        x = self.mid_block2(x, mid_c)
+        x = self.mid_block2(x, mid_c, t)
 
         for convnext, sparse_attn, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = convnext(x, c)
+            x = convnext(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c)
+            x = convnext2(x, c, t)
             x = upsample(x)
 
         return self.final_conv(x)
diff --git a/setup.py b/setup.py
index 441e8b0..e3ed51a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.82',
+  version = '0.0.84',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
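For readers skimming the diff: the new `time_cond_dim` / `time_mlp` wiring gives every ConvNextBlock a FiLM-style additive time signal. Below is a minimal, self-contained sketch of just that pathway, not code from the repo; `TimeConditionedBlock` and its reduced inner conv are illustrative stand-ins, while `exists` mirrors the repo's helper of the same name.

import torch
from torch import nn
from einops import rearrange

def exists(val):
    # same helper the repo defines
    return val is not None

class TimeConditionedBlock(nn.Module):
    # hypothetical, stripped-down analogue of ConvNextBlock, keeping only
    # the depthwise conv and the time-conditioning path added by this diff
    def __init__(self, dim, dim_out, time_cond_dim = None):
        super().__init__()
        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.GELU(),
                nn.Linear(time_cond_dim, dim)
            )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
        self.net = nn.Conv2d(dim, dim_out, 3, padding = 1)  # stand-in for the real inner net

    def forward(self, x, time = None):
        h = self.ds_conv(x)

        if exists(time) and exists(self.time_mlp):
            # project the time embedding to `dim` channels and broadcast-add
            # it over height and width, as the new forward() does
            t = self.time_mlp(time)
            h = rearrange(t, 'b c -> b c 1 1') + h

        return self.net(h)

block = TimeConditionedBlock(dim = 32, dim_out = 64, time_cond_dim = 128)
x = torch.randn(2, 32, 16, 16)     # feature map
time = torch.randn(2, 128)         # per-sample time embedding
out = block(x, time = time)        # -> (2, 64, 16, 16)

Because the addition happens before the cross-attention step, the time signal conditions every spatial position uniformly, independently of the image and text tokens.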
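On the Unet side, the diff splits the old `time_mlp` into a shared trunk (`to_time_hiddens`) feeding two heads: `to_time_tokens` produces tokens for cross-attention context, and `to_time_cond` produces the per-block vector `t` used above. A rough standalone sketch under assumed sizes (`dim = 64`, `num_time_tokens = 2` are illustrative; this `SinusoidalPosEmb` is a stand-in for the repo's class of the same name):

import torch
from torch import nn
from einops.layers.torch import Rearrange

# assumed sizes for illustration; the diff derives these from the Unet's `dim`
dim = 64
time_cond_dim = dim * 4     # matches `time_cond_dim = dim * 4` in the diff
cond_dim = dim
num_time_tokens = 2

class SinusoidalPosEmb(nn.Module):
    # stand-in for the repo's class of the same name: maps (b,) -> (b, dim)
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(torch.arange(half) * -(torch.log(torch.tensor(10000.)) / (half - 1)))
        args = t[:, None].float() * freqs[None]
        return torch.cat((args.sin(), args.cos()), dim = -1)

# shared trunk, then two heads, mirroring the diff's split of the old time_mlp
to_time_hiddens = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, time_cond_dim),
    nn.GELU()
)

to_time_tokens = nn.Sequential(
    nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
    Rearrange('b (r d) -> b r d', r = num_time_tokens)
)

to_time_cond = nn.Linear(time_cond_dim, time_cond_dim)

times = torch.randint(0, 1000, (4,))
time_hiddens = to_time_hiddens(times)        # (4, 256) shared time features
time_tokens = to_time_tokens(time_hiddens)   # (4, 2, 64) tokens for cross attention
t = to_time_cond(time_hiddens)               # (4, 256) vector each block adds spatially

Sharing the trunk lets one sinusoidal embedding and projection drive both the cross-attention tokens and the per-block additive conditioning, rather than recomputing them per head.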