mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
fix missing resisidual for highest resolution of the unet
This commit is contained in:
@@ -1489,8 +1489,10 @@ class Unet(nn.Module):
|
|||||||
Upsample(dim_in)
|
Upsample(dim_in)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
final_dim_in = dim * (1 if memory_efficient else 2)
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]),
|
||||||
nn.Conv2d(dim, self.channels_out, 1)
|
nn.Conv2d(dim, self.channels_out, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1682,7 +1684,7 @@ class Unet(nn.Module):
|
|||||||
x = self.mid_block2(x, mid_c, t)
|
x = self.mid_block2(x, mid_c, t)
|
||||||
|
|
||||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||||
x = init_block(x, c, t)
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
|
|
||||||
@@ -1691,6 +1693,9 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
if len(hiddens) > 0:
|
||||||
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
class LowresConditioner(nn.Module):
|
class LowresConditioner(nn.Module):
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.8.1'
|
__version__ = '0.9.0'
|
||||||
|
|||||||
Reference in New Issue
Block a user