mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add PixelShuffleUpsample thanks to @MalumaDev and @marunine for running the experiment and verifyng absence of checkboard artifacts
This commit is contained in:
@@ -1223,16 +1223,35 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
def ConvTransposeUpsample(dim, dim_out = None):
|
class PixelShuffleUpsample(nn.Module):
|
||||||
dim_out = default(dim_out, dim)
|
"""
|
||||||
return nn.ConvTranspose2d(dim, dim_out, 4, 2, 1)
|
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
|
||||||
|
https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
|
||||||
|
"""
|
||||||
|
def __init__(self, dim, dim_out = None):
|
||||||
|
super().__init__()
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
||||||
|
|
||||||
def NearestUpsample(dim, dim_out = None):
|
self.net = nn.Sequential(
|
||||||
dim_out = default(dim_out, dim)
|
conv,
|
||||||
return nn.Sequential(
|
nn.SiLU(),
|
||||||
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
nn.PixelShuffle(2)
|
||||||
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
)
|
||||||
)
|
|
||||||
|
self.init_conv_(conv)
|
||||||
|
|
||||||
|
def init_conv_(self, conv):
|
||||||
|
o, i, h, w = conv.weight.shape
|
||||||
|
conv_weight = torch.empty(o // 4, i, h, w)
|
||||||
|
nn.init.kaiming_uniform_(conv_weight)
|
||||||
|
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
|
||||||
|
|
||||||
|
conv.weight.data.copy_(conv_weight)
|
||||||
|
nn.init.zeros_(conv.bias.data)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
def Downsample(dim, *, dim_out = None):
|
def Downsample(dim, *, dim_out = None):
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
@@ -1496,7 +1515,7 @@ class Unet(nn.Module):
|
|||||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||||
memory_efficient = False,
|
memory_efficient = False,
|
||||||
scale_skip_connection = False,
|
scale_skip_connection = False,
|
||||||
nearest_upsample = False,
|
pixel_shuffle_upsample = True,
|
||||||
final_conv_kernel_size = 1,
|
final_conv_kernel_size = 1,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@@ -1610,7 +1629,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# upsample klass
|
# upsample klass
|
||||||
|
|
||||||
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
|
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
||||||
|
|
||||||
# give memory efficient unet an initial resnet block
|
# give memory efficient unet an initial resnet block
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.20.1'
|
__version__ = '0.21.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user