mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
optional projection out for prior network causal transformer
This commit is contained in:
@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
|
||||
]))
|
||||
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
out = self.norm(x)
|
||||
return self.project_out(out)
|
||||
|
||||
class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user