mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
let the neural network peek at the low resolution conditioning one last time before making prediction, for upsamplers
This commit is contained in:
@@ -1731,7 +1731,10 @@ class Unet(nn.Module):
|
|||||||
]))
|
]))
|
||||||
|
|
||||||
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
|
||||||
self.to_out = nn.Conv2d(dim, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
|
||||||
|
out_dim_in = dim + (channels if lowres_cond else 0)
|
||||||
|
|
||||||
|
self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
|
||||||
|
|
||||||
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
|
||||||
|
|
||||||
@@ -1951,6 +1954,10 @@ class Unet(nn.Module):
|
|||||||
x = torch.cat((x, r), dim = 1)
|
x = torch.cat((x, r), dim = 1)
|
||||||
|
|
||||||
x = self.final_resnet_block(x, t)
|
x = self.final_resnet_block(x, t)
|
||||||
|
|
||||||
|
if exists(lowres_cond_img):
|
||||||
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||||
|
|
||||||
return self.to_out(x)
|
return self.to_out(x)
|
||||||
|
|
||||||
class LowresConditioner(nn.Module):
|
class LowresConditioner(nn.Module):
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.23.10'
|
__version__ = '0.24.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user