mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 02:04:19 +01:00
Never mind, it could be working, but only when I stabilize it with the feedforward layer + tanh, as proposed in the ViT-VQGAN paper (which will be built into the repository later for latent diffusion)
This commit is contained in:
@@ -44,7 +44,12 @@ class QueryAndAttend(nn.Module):
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim * 2, 1, bias = False),
|
||||
nn.Tanh(),
|
||||
nn.Conv2d(dim * 2, dim, 1, bias = False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user