Compare commits

..

2 Commits

Author SHA1 Message Date
Phil Wang
6651eafa93 one more residual, after seeing good results on unconditional generation locally 2022-06-16 11:18:02 -07:00
Phil Wang
e6bb75e5ab fix missing residual for highest resolution of the unet 2022-06-15 20:09:43 -07:00
2 changed files with 4 additions and 5 deletions

View File

@@ -1488,7 +1488,7 @@ class Unet(nn.Module):
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1560,6 +1560,7 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1689,9 +1690,7 @@ class Unet(nn.Module):
x = upsample(x)
if len(hiddens) > 0:
x = torch.cat((x, hiddens.pop()), dim = 1)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):

View File

@@ -1 +1 @@
__version__ = '0.9.1'
__version__ = '0.9.2'