mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 10:06:13 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9faab59b23 |
@@ -1181,7 +1181,11 @@ class CrossAttention(nn.Module):
|
|||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
|
LayerNorm(dim)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context, mask = None):
|
def forward(self, x, context, mask = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
|
|||||||
Reference in New Issue
Block a user