dropouts in transformer, also prep for classifier free guidance in decoder

Author: Phil Wang
Date:   2022-04-12 10:42:57 -07:00
parent 604765b563
commit 0a60818965

@@ -44,12 +44,13 @@ def freeze_model_and_make_eval_(model):
 # diffusion prior

-def FeedForward(dim, mult = 4):
+def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias = False),
         nn.GELU(),
+        nn.Dropout(dropout),
         nn.Linear(inner_dim, dim, bias = False)
     )
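This hunk places a dropout between the GELU activation and the output projection. A minimal standalone sketch of the resulting module (the 0.1 rate and tensor shapes below are illustrative, not from the commit):

import torch
from torch import nn

def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(mult * dim)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),                     # new in this commit; a no-op at the default 0.
        nn.Linear(inner_dim, dim, bias = False)
    )

ff = FeedForward(dim = 512, dropout = 0.1)       # illustrative rate
x = torch.randn(2, 16, 512)                      # (batch, sequence, dim)
assert ff(x).shape == x.shape                    # dropout is identity under ff.eval()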
@@ -59,13 +60,16 @@ class Attention(nn.Module):
         *,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads

         self.norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
+
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
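Note this hunk only registers self.dropout; the diff does not show where it is applied. In a standard pre-norm attention block the dropout acts on the attention weights after the softmax, so a plausible forward pass would look like the sketch below (the forward body is an assumption based on common practice, not taken from this commit):

import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, *, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)       # registered by this commit

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)                # assumed placement: post-softmax attention dropout

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)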
@@ -106,7 +110,9 @@ class Transformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False
+        norm_out = False,
+        attn_dropout = 0.,
+        ff_dropout = 0.
     ):
         super().__init__()
         # todo - bring in rotary embeddings or alibi
@@ -114,8 +120,8 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads),
-                FeedForward(dim = dim, mult = ff_mult)
+                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))

         self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
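With both Transformer hunks applied, the two rates are plumbed from the constructor down into every layer. A usage sketch, assuming dim and depth are the remaining constructor arguments as the loop body suggests (all values illustrative):

transformer = Transformer(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    norm_out = True,
    attn_dropout = 0.05,    # forwarded to each Attention
    ff_dropout = 0.05       # forwarded to each FeedForward
)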
@@ -168,7 +174,8 @@ class Decoder(nn.Module):
         *,
         image,
         image_embed,
+        cond_drop_prob = 0.2, # for the classifier free guidance
         text_embed = None     # in paper, text embedding was optional for conditioning decoder
     ):

         return image
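cond_drop_prob is only threaded through here; the forward is still a stub that returns image. For context, classifier-free guidance trains the model with the conditioning randomly dropped at this probability, then at sampling time mixes conditional and unconditional predictions. A sketch of that pattern (the helper names and masking scheme are illustrative, not from this commit):

import torch

def drop_conditioning(image_embed, cond_drop_prob):
    # training: zero out the conditioning for a random subset of the batch
    batch = image_embed.shape[0]
    keep = (torch.rand(batch, device = image_embed.device) >= cond_drop_prob).float()
    return image_embed * keep[:, None]

def guided_output(decoder, image, image_embed, cond_scale = 3.):
    # sampling: push the conditional prediction away from the unconditional one
    cond = decoder(image = image, image_embed = image_embed, cond_drop_prob = 0.)
    null = decoder(image = image, image_embed = torch.zeros_like(image_embed), cond_drop_prob = 0.)
    return null + cond_scale * (cond - null)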