From b7f9607258497fbc2ae81ad71240f79486455402 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 15 Jun 2022 13:40:26 -0700
Subject: [PATCH] make memory efficient unet design from imagen toggle-able

---
 dalle2_pytorch/dalle2_pytorch.py | 19 ++++++++++++++-----
 dalle2_pytorch/version.py        |  2 +-
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 1ef9b3a..887c690 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1352,6 +1352,7 @@ class Unet(nn.Module):
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
+        memory_efficient = False,
         **kwargs
     ):
         super().__init__()
@@ -1462,10 +1463,11 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                downsample_klass(dim_in, dim_out = dim_out),
-                ResnetBlock(dim_out, dim_out, time_cond_dim = time_cond_dim, groups = groups),
+                downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
+                ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                downsample_klass(dim_out) if not is_last and not memory_efficient else None
             ]))
 
         mid_dim = dims[-1]
@@ -1474,7 +1476,9 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
-        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks))):
+        up_in_out_slice = slice(1 if not memory_efficient else None, None)
+
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[up_in_out_slice]), reversed(resnet_groups), reversed(num_resnet_blocks))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
@@ -1655,8 +1659,10 @@ class Unet(nn.Module):
 
         hiddens = []
 
-        for downsample, init_block, sparse_attn, resnet_blocks in self.downs:
-            x = downsample(x)
+        for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
+            if exists(pre_downsample):
+                x = pre_downsample(x)
+
             x = init_block(x, c, t)
             x = sparse_attn(x)
 
@@ -1665,6 +1671,9 @@ class Unet(nn.Module):
 
             hiddens.append(x)
 
+            if exists(post_downsample):
+                x = post_downsample(x)
+
         x = self.mid_block1(x, mid_c, t)
 
         if exists(self.mid_attn):
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 32a90a3..ef72cc0 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.8.0'
+__version__ = '0.8.1'
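
For reference, a minimal usage sketch of the new flag (not part of the
patch; every argument except memory_efficient is illustrative, following
the Unet constructor as used in the project README):

    from dalle2_pytorch import Unet

    unet = Unet(
        dim = 128,                 # illustrative values, per the README
        image_embed_dim = 512,
        cond_dim = 128,
        channels = 3,
        dim_mults = (1, 2, 4, 8),
        memory_efficient = True    # True: downsample at the start of each
                                   # resolution stage, so the resnet blocks
                                   # run on smaller feature maps (Imagen's
                                   # efficient unet design). False (the
                                   # default): downsample at the end of
                                   # each stage except the last.
    )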