diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index d388ba3..5b76322 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1223,16 +1223,43 @@ class DiffusionPrior(nn.Module):
 
 # decoder
 
 def ConvTransposeUpsample(dim, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
 
-def NearestUpsample(dim, dim_out = None):
-    dim_out = default(dim_out, dim)
-    return nn.Sequential(
-        nn.Upsample(scale_factor = 2, mode = 'nearest'),
-        nn.Conv2d(dim, dim_out, 3, padding = 1)
-    )
+class PixelShuffleUpsample(nn.Module):
+    """
+    2x spatial upsample: 1x1 conv to 4 * dim_out channels, SiLU, then PixelShuffle(2)
+
+    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
+    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+    """
+    def __init__(self, dim, dim_out = None):
+        super().__init__()
+        dim_out = default(dim_out, dim)
+        conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+        self.net = nn.Sequential(
+            conv,
+            nn.SiLU(),
+            nn.PixelShuffle(2)
+        )
+
+        self.init_conv_(conv)
+
+    def init_conv_(self, conv):
+        # initialize only a quarter of the output filters, then repeat each one 4 times in a row,
+        # so the 4 channels PixelShuffle(2) folds into one output channel start out identical
+        o, i, h, w = conv.weight.shape
+        conv_weight = torch.empty(o // 4, i, h, w)
+        nn.init.kaiming_uniform_(conv_weight)
+        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
+
+        conv.weight.data.copy_(conv_weight)
+        nn.init.zeros_(conv.bias.data)
+
+    def forward(self, x):
+        return self.net(x)
 
 def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
@@ -1496,7 +1523,7 @@ class Unet(nn.Module):
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
         scale_skip_connection = False,
-        nearest_upsample = False,
+        pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
         **kwargs
     ):
@@ -1610,7 +1637,7 @@ class Unet(nn.Module):
 
         # upsample klass
 
-        upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
+        upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
 
         # give memory efficient unet an initial resnet block
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index abadaef..e453371 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.20.1'
+__version__ = '0.21.0'