revert back to old upsampling, paper does not work

This commit is contained in:
Phil Wang
2022-04-26 07:39:04 -07:00
parent 45262a4bb7
commit 0b28ee0d01
3 changed files with 2 additions and 12 deletions

View File

@@ -693,7 +693,7 @@ class DiffusionPrior(nn.Module):
# decoder
def Upsample(dim):
return QueryAttnUpsample(dim)
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)

View File

@@ -378,7 +378,7 @@ class VQGanVAE(nn.Module):
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))