mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
dropouts in transformer, also prep for classifier free guidance in decoder
This commit is contained in:
@@ -44,12 +44,13 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
def FeedForward(dim, mult = 4):
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
|
||||
@@ -59,13 +60,16 @@ class Attention(nn.Module):
|
||||
*,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
@@ -106,7 +110,9 @@ class Transformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_out = False
|
||||
norm_out = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
# todo - bring in rotary embeddings or alibi
|
||||
@@ -114,8 +120,8 @@ class Transformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads),
|
||||
FeedForward(dim = dim, mult = ff_mult)
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
@@ -168,7 +174,8 @@ class Decoder(nn.Module):
|
||||
*,
|
||||
image,
|
||||
image_embed,
|
||||
text_embed = None # in paper, text embedding was optional for conditioning decoder
|
||||
cond_drop_prob = 0.2, # for the classifier free guidance
|
||||
text_embed = None # in paper, text embedding was optional for conditioning decoder
|
||||
):
|
||||
return image
|
||||
|
||||
|
||||
Reference in New Issue
Block a user