diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 769e6c5..6034aa3 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1731,7 +1731,10 @@ class Unet(nn.Module): ])) self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) - self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2) + + out_dim_in = dim + (channels if lowres_cond else 0) + + self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2) zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it @@ -1951,6 +1954,10 @@ class Unet(nn.Module): x = torch.cat((x, r), dim = 1) x = self.final_resnet_block(x, t) + + if exists(lowres_cond_img): + x = torch.cat((x, lowres_cond_img), dim = 1) + return self.to_out(x) class LowresConditioner(nn.Module): diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ce97d1d..f8ab8c2 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.23.10' +__version__ = '0.24.0'