mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix non pixel shuffle upsample
This commit is contained in:
@@ -1251,6 +1251,14 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
|
def NearestUpsample(dim, dim_out = None):
|
||||||
|
dim_out = default(dim_out, dim)
|
||||||
|
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
||||||
|
nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||||
|
)
|
||||||
|
|
||||||
class PixelShuffleUpsample(nn.Module):
|
class PixelShuffleUpsample(nn.Module):
|
||||||
"""
|
"""
|
||||||
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
|
code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts
|
||||||
@@ -1657,7 +1665,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# upsample klass
|
# upsample klass
|
||||||
|
|
||||||
upsample_klass = ConvTransposeUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
|
||||||
|
|
||||||
# give memory efficient unet an initial resnet block
|
# give memory efficient unet an initial resnet block
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.23.5'
|
__version__ = '0.23.6'
|
||||||
|
|||||||
Reference in New Issue
Block a user