diff --git a/README.md b/README.md
index ea96b25..ecbc861 100644
--- a/README.md
+++ b/README.md
@@ -897,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Shleifer2021NormFormerIT,
+    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
+    author  = {Sam Shleifer and Jason Weston and Myle Ott},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2110.09456}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 17d7924..2ab0e7e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -499,7 +499,12 @@ class SwiGLU(nn.Module):
         x, gate = x.chunk(2, dim = -1)
         return x * F.silu(gate)
 
-def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
+def FeedForward(
+    dim,
+    mult = 4,
+    dropout = 0.,
+    post_activation_norm = False
+):
     """ post-activation norm https://arxiv.org/abs/2110.09456 """
 
     inner_dim = int(mult * dim)
@@ -522,7 +527,8 @@ class Attention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        causal = False
+        causal = False,
+        post_norm = False
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -537,7 +543,11 @@ class Attention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim) if post_norm else nn.Identity()
+        )
 
     def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
@@ -602,7 +612,8 @@ class CausalTransformer(nn.Module):
         norm_out = True,
         attn_dropout = 0.,
         ff_dropout = 0.,
-        final_proj = True
+        final_proj = True,
+        normformer = False
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
@@ -610,8 +621,8 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
-                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
diff --git a/setup.py b/setup.py
index 68367dc..0b6246a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.92',
+  version = '0.0.93',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
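
For readers looking at this diff in isolation, below is a minimal standalone sketch of what the two new flags do. It is an assumption-laden illustration, not the package's code: it uses plain `nn.LayerNorm` in place of the repo's custom `LayerNorm`, `OutProjection` is a hypothetical helper name used only to mirror the `to_out` change, and the `FeedForward` internals beyond the lines visible in the diff are inferred from its docstring's reference to NormFormer (https://arxiv.org/abs/2110.09456).

```python
import torch
from torch import nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    # gated unit used by the feedforward: half the features gate the other half
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
    # post_activation_norm = True inserts the extra NormFormer LayerNorm
    # directly after the gated activation
    inner_dim = int(mult * dim)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        SwiGLU(),
        nn.LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )

# hypothetical helper, analogous to the attention-side change in the diff:
# post_norm = True appends a LayerNorm after the attention output projection
def OutProjection(inner_dim, dim, post_norm = False):
    return nn.Sequential(
        nn.Linear(inner_dim, dim, bias = False),
        nn.LayerNorm(dim) if post_norm else nn.Identity()
    )

x = torch.randn(2, 16, 512)
print(FeedForward(512, post_activation_norm = True)(x).shape)  # torch.Size([2, 16, 512])
print(OutProjection(512, 512, post_norm = True)(x).shape)      # torch.Size([2, 16, 512])
```

Per the `CausalTransformer` hunk, passing `normformer = True` switches both flags on for every layer, so each residual branch ends in a LayerNorm, which is the extra normalization the NormFormer paper credits with more stable pretraining.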