diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index f35e9b3..003b9d3 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1251,6 +1251,14 @@ class DiffusionPrior(nn.Module): # decoder +def NearestUpsample(dim, dim_out = None): + dim_out = default(dim_out, dim) + + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, dim_out, 3, padding = 1) + ) + class PixelShuffleUpsample(nn.Module): """ code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts @@ -1657,7 +1665,7 @@ class Unet(nn.Module): # upsample klass - upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample + upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample # give memory efficient unet an initial resnet block diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index a123ffa..50690f9 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.23.5' +__version__ = '0.23.6'