one more residual, after seeing good results on unconditional generation locally

This commit is contained in:
Phil Wang
2022-06-16 11:18:02 -07:00
parent e6bb75e5ab
commit 6651eafa93
2 changed files with 4 additions and 2 deletions

View File

@@ -1488,7 +1488,7 @@ class Unet(nn.Module):
]))
self.final_conv = nn.Sequential(
ResnetBlock(dim, dim, groups = resnet_groups[0]),
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
nn.Conv2d(dim, self.channels_out, 1)
)
@@ -1560,6 +1560,7 @@ class Unet(nn.Module):
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
@@ -1689,6 +1690,7 @@ class Unet(nn.Module):
x = upsample(x)
x = torch.cat((x, r), dim = 1)
return self.final_conv(x)
class LowresConditioner(nn.Module):