From a2ee3fa3cc3650a8d226e908df8854d0a3ca321a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 15 Jul 2022 17:29:10 -0700 Subject: [PATCH] offer way to turn off initial cross embed convolutional module, for debugging upsampler artifacts --- dalle2_pytorch/dalle2_pytorch.py | 3 ++- dalle2_pytorch/train_configs.py | 1 + dalle2_pytorch/version.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 4626080..ff8fc77 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1550,6 +1550,7 @@ class Unet(nn.Module): init_conv_kernel_size = 7, resnet_groups = 8, num_resnet_blocks = 2, + init_cross_embed = True, init_cross_embed_kernel_sizes = (3, 7, 15), cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), @@ -1578,7 +1579,7 @@ class Unet(nn.Module): init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis init_dim = default(init_dim, dim) - self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) + self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index fc24282..dd30e46 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -225,6 +225,7 @@ class UnetConfig(BaseModel): self_attn: ListOrTuple(int) attn_dim_head: int = 32 attn_heads: int = 16 + init_cross_embed: bool = True class Config: extra = "allow" diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index d1d123f..c6d6a56 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.24.2' +__version__ = '0.24.3'