mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
one more residual, after seeing good results on unconditional generation locally
This commit is contained in:
@@ -1488,7 +1488,7 @@ class Unet(nn.Module):
|
||||
]))
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
ResnetBlock(dim * 2, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, self.channels_out, 1)
|
||||
)
|
||||
|
||||
@@ -1560,6 +1560,7 @@ class Unet(nn.Module):
|
||||
# initial convolution
|
||||
|
||||
x = self.init_conv(x)
|
||||
r = x.clone() # final residual
|
||||
|
||||
# time conditioning
|
||||
|
||||
@@ -1689,6 +1690,7 @@ class Unet(nn.Module):
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
x = torch.cat((x, r), dim = 1)
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user