never mind, it could be working, but only when i stabilize it with the feedforward layer + tanh as proposed in the vit-vqgan paper (which will be built into the repository later for the latent diffusion)

This commit is contained in:
Phil Wang
2022-04-26 12:43:31 -07:00
parent de0296106b
commit 4075d02139
2 changed files with 26 additions and 1 deletion


@@ -44,7 +44,12 @@ class QueryAndAttend(nn.Module):
         self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
         self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
-        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Conv2d(inner_dim, dim * 2, 1, bias = False),
+            nn.Tanh(),
+            nn.Conv2d(dim * 2, dim, 1, bias = False)
+        )
     def forward(self, x):
         """