add PixelShuffleUpsample - thanks to @MalumaDev and @marunine for running the experiment and verifying the absence of checkerboard artifacts

This commit is contained in:
Phil Wang
2022-07-11 16:07:23 -07:00
parent bdd62c24b3
commit 1d9ef99288
2 changed files with 31 additions and 12 deletions

View File

@@ -1223,17 +1223,36 @@ class DiffusionPrior(nn.Module):
# decoder # decoder
def ConvTransposeUpsample(dim, dim_out = None): class PixelShuffleUpsample(nn.Module):
"""
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
"""
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1) conv = nn.Conv2d(dim, dim_out * 4, 1)
def NearestUpsample(dim, dim_out = None): self.net = nn.Sequential(
dim_out = default(dim_out, dim) conv,
return nn.Sequential( nn.SiLU(),
nn.Upsample(scale_factor = 2, mode = 'nearest'), nn.PixelShuffle(2)
nn.Conv2d(dim, dim_out, 3, padding = 1)
) )
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
def Downsample(dim, *, dim_out = None): def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1) return nn.Conv2d(dim, dim_out, 4, 2, 1)
@@ -1496,7 +1515,7 @@ class Unet(nn.Module):
cross_embed_downsample_kernel_sizes = (2, 4), cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False, memory_efficient = False,
scale_skip_connection = False, scale_skip_connection = False,
nearest_upsample = False, pixel_shuffle_upsample = True,
final_conv_kernel_size = 1, final_conv_kernel_size = 1,
**kwargs **kwargs
): ):
@@ -1610,7 +1629,7 @@ class Unet(nn.Module):
# upsample klass # upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# give memory efficient unet an initial resnet block # give memory efficient unet an initial resnet block

View File

@@ -1 +1 @@
__version__ = '0.20.1' __version__ = '0.21.0'