mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00

dropouts in transformer, also prep for classifier free guidance in decoder
@@ -44,12 +44,13 @@ def freeze_model_and_make_eval_(model):
 
 # diffusion prior
 
-def FeedForward(dim, mult = 4):
+def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias = False),
         nn.GELU(),
+        nn.Dropout(dropout),
         nn.Linear(inner_dim, dim, bias = False)
     )
 
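The new dropout argument defaults to 0., so existing callers are unaffected. A minimal sketch (not part of the commit) of the resulting block, showing that the added nn.Dropout layer only takes effect in train mode:

import torch
from torch import nn

def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(mult * dim)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),  # identity when dropout = 0. (the default)
        nn.Linear(inner_dim, dim, bias = False)
    )

ff = FeedForward(dim = 512, dropout = 0.1)
x = torch.randn(2, 77, 512)

ff.train()
out_train = ff(x)   # dropout active: random activations zeroed each forward pass

ff.eval()
out_eval = ff(x)    # dropout disabled at inference
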
@@ -59,13 +60,16 @@ class Attention(nn.Module):
         *,
         dim,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads
 
         self.norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
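This hunk only touches __init__ and does not show where self.dropout is applied. The sketch below assumes the conventional placement, on the post-softmax attention weights; the class name and forward body are illustrative, not the repo's code.

import torch
from torch import nn
from einops import rearrange

class AttentionSketch(nn.Module):  # hypothetical name, not the repo's Attention class
    def __init__(self, *, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)  # assumed placement: drop attention weights after softmax

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
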
@@ -106,7 +110,9 @@ class Transformer(nn.Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        norm_out = False
+        norm_out = False,
+        attn_dropout = 0.,
+        ff_dropout = 0.
     ):
         super().__init__()
         # todo - bring in rotary embeddings or alibi
@@ -114,8 +120,8 @@ class Transformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads),
-                FeedForward(dim = dim, mult = ff_mult)
+                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
         self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
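A hypothetical instantiation showing how the two new kwargs are threaded through to every layer; it assumes the unchanged parts of the constructor also take dim and depth (the loop above iterates over range(depth)).

transformer = Transformer(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    norm_out = True,
    attn_dropout = 0.05,  # forwarded to each Attention layer
    ff_dropout = 0.05     # forwarded to each FeedForward block
)
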
@@ -168,7 +174,8 @@ class Decoder(nn.Module):
         *,
         image,
         image_embed,
-        text_embed = None # in paper, text embedding was optional for conditioning decoder
+        cond_drop_prob = 0.2, # for the classifier free guidance
+        text_embed = None # in paper, text embedding was optional for conditioning decoder
     ):
         return image
 
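cond_drop_prob only prepares for classifier-free guidance; the Decoder forward above is still a stub. Below is a sketch of the standard recipe the parameter implies, with hypothetical helper names (prob_mask_like, drop_conditioning, guided_noise_pred) that are not from the repo: during training, conditioning is dropped for roughly a cond_drop_prob fraction of the batch so the model also learns an unconditional path; at sampling, conditional and unconditional predictions are combined with a guidance scale.

import torch

def prob_mask_like(shape, prob, device):
    # True with probability `prob` -> positions whose conditioning gets dropped
    return torch.rand(shape, device = device) < prob

def drop_conditioning(image_embed, cond_drop_prob):
    # training-time conditioning dropout: replace the image embedding with a
    # zero ("null") embedding for a random subset of the batch
    batch, device = image_embed.shape[0], image_embed.device
    drop_mask = prob_mask_like((batch, 1), cond_drop_prob, device)
    return torch.where(drop_mask, torch.zeros_like(image_embed), image_embed)

def guided_noise_pred(model, x, t, image_embed, cond_scale = 3.):
    # sampling-time guidance: push the conditional prediction away from the
    # unconditional one by the guidance scale
    cond = model(x, t, image_embed = image_embed)
    uncond = model(x, t, image_embed = torch.zeros_like(image_embed))
    return uncond + (cond - uncond) * cond_scale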