diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 01ad5ff..3c6afb5 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1488,7 +1488,7 @@ class Unet(nn.Module): ])) self.final_conv = nn.Sequential( - ResnetBlock(dim, dim, groups = resnet_groups[0]), + ResnetBlock(dim * 2, dim, groups = resnet_groups[0]), nn.Conv2d(dim, self.channels_out, 1) ) @@ -1560,6 +1560,7 @@ class Unet(nn.Module): # initial convolution x = self.init_conv(x) + r = x.clone() # final residual # time conditioning @@ -1689,6 +1690,7 @@ class Unet(nn.Module): x = upsample(x) + x = torch.cat((x, r), dim = 1) return self.final_conv(x) class LowresConditioner(nn.Module): diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8969d49..1f04780 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.9.1' +__version__ = '0.9.2'