fix non pixel shuffle upsample

This commit is contained in:
Phil Wang
2022-07-13 10:16:02 -07:00
parent 95a512cb65
commit cc0f7a935c
2 changed files with 10 additions and 2 deletions

View File

@@ -1251,6 +1251,14 @@ class DiffusionPrior(nn.Module):
# decoder # decoder
def NearestUpsample(dim, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, dim_out, 3, padding = 1)
)
class PixelShuffleUpsample(nn.Module): class PixelShuffleUpsample(nn.Module):
""" """
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
@@ -1657,7 +1665,7 @@ class Unet(nn.Module):
# upsample klass # upsample klass
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# give memory efficient unet an initial resnet block # give memory efficient unet an initial resnet block

View File

@@ -1 +1 @@
__version__ = '0.23.5' __version__ = '0.23.6'