Compare commits

..

1 Commits

Author SHA1 Message Date
Phil Wang
57f1ddf9d2 fix missing resisidual for highest resolution of the unet 2022-06-15 19:11:58 -07:00
2 changed files with 5 additions and 4 deletions

View File

@@ -1488,7 +1488,7 @@ class Unet(nn.Module):
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
ResnetBlock(dim, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1560,7 +1560,6 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1690,7 +1689,9 @@ class Unet(nn.Module):
x = upsample(x)
x = torch.cat((x, r), dim = 1)
if len(hiddens) > 0:
x = torch.cat((x, hiddens.pop()), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):

View File

@@ -1 +1 @@
__version__ = '0.9.2'
__version__ = '0.9.1'