From 6647050c335987d8a3474bb210f0929851fb6dd9 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 15 Jun 2022 18:01:12 -0700 Subject: [PATCH] fix missing residual for highest resolution of the unet --- dalle2_pytorch/dalle2_pytorch.py | 9 +++++++-- dalle2_pytorch/version.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 887c690..6569810 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1489,8 +1489,10 @@ class Unet(nn.Module): Upsample(dim_in) ])) + final_dim_in = dim * (1 if memory_efficient else 2) + self.final_conv = nn.Sequential( - ResnetBlock(dim, dim, groups = resnet_groups[0]), + ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]), nn.Conv2d(dim, self.channels_out, 1) ) @@ -1682,7 +1684,7 @@ class Unet(nn.Module): x = self.mid_block2(x, mid_c, t) for init_block, sparse_attn, resnet_blocks, upsample in self.ups: - x = torch.cat((x, hiddens.pop()), dim=1) + x = torch.cat((x, hiddens.pop()), dim = 1) x = init_block(x, c, t) x = sparse_attn(x) @@ -1691,6 +1693,9 @@ class Unet(nn.Module): x = upsample(x) + if len(hiddens) > 0: + x = torch.cat((x, hiddens.pop()), dim = 1) + return self.final_conv(x) class LowresConditioner(nn.Module): diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ef72cc0..e4e49b3 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.8.1' +__version__ = '0.9.0'