From 6651eafa937000ec72ef3b3892e22a8dc2997316 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 16 Jun 2022 11:18:02 -0700 Subject: [PATCH] one more residual, after seeing good results on unconditional generation locally --- dalle2_pytorch/dalle2_pytorch.py | 4 +++- dalle2_pytorch/version.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 01ad5ff..3c6afb5 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1488,7 +1488,7 @@ class Unet(nn.Module): ])) self.final_conv = nn.Sequential( - ResnetBlock(dim, dim, groups = resnet_groups[0]), + ResnetBlock(dim * 2, dim, groups = resnet_groups[0]), nn.Conv2d(dim, self.channels_out, 1) ) @@ -1560,6 +1560,7 @@ class Unet(nn.Module): # initial convolution x = self.init_conv(x) + r = x.clone() # final residual # time conditioning @@ -1689,6 +1690,7 @@ class Unet(nn.Module): x = upsample(x) + x = torch.cat((x, r), dim = 1) return self.final_conv(x) class LowresConditioner(nn.Module): diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8969d49..1f04780 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.9.1' +__version__ = '0.9.2'