diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b9aef83..12578a7 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -44,12 +44,13 @@ def freeze_model_and_make_eval_(model): # diffusion prior -def FeedForward(dim, mult = 4): +def FeedForward(dim, mult = 4, dropout = 0.): inner_dim = int(mult * dim) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), + nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias = False) ) @@ -59,13 +60,16 @@ class Attention(nn.Module): *, dim, dim_head = 64, - heads = 8 + heads = 8, + dropout = 0. ): super().__init__() self.scale = dim_head ** -0.5 inner_dim = dim_head * heads self.norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) @@ -106,7 +110,9 @@ class Transformer(nn.Module): dim_head = 64, heads = 8, ff_mult = 4, - norm_out = False + norm_out = False, + attn_dropout = 0., + ff_dropout = 0. ): super().__init__() # todo - bring in rotary embeddings or alibi @@ -114,8 +120,8 @@ class Transformer(nn.Module): self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ - Attention(dim = dim, dim_head = dim_head, heads = heads), - FeedForward(dim = dim, mult = ff_mult) + Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout), + FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) ])) self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options @@ -168,7 +174,8 @@ class Decoder(nn.Module): *, image, image_embed, - text_embed = None # in paper, text embedding was optional for conditioning decoder + cond_drop_prob = 0.2, # for the classifier free guidance + text_embed = None # in paper, text embedding was optional for conditioning decoder ): return image