From 522f42f5827f9f939efd52fdd6f69e055fb354da Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 10:45:03 -0700 Subject: [PATCH] start using RMSNorm, used in Gopher and AlphaCode, and as a way to go complete bias-less (purportedly more stable according to PaLM) --- dalle2_pytorch/dalle2_pytorch.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 12578a7..1e06b55 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -44,10 +44,21 @@ def freeze_model_and_make_eval_(model): # diffusion prior +class RMSNorm(nn.Module): + def __init__(self, dim, eps = 1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim = -1, keepdim = True) * self.scale + return x / norm.clamp(min = self.eps) * self.g + def FeedForward(dim, mult = 4, dropout = 0.): inner_dim = int(mult * dim) return nn.Sequential( - nn.LayerNorm(dim), + RMSNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), nn.Dropout(dropout), @@ -67,7 +78,7 @@ class Attention(nn.Module): self.scale = dim_head ** -0.5 inner_dim = dim_head * heads - self.norm = nn.LayerNorm(dim) + self.norm = RMSNorm(dim) self.dropout = nn.Dropout(dropout) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) @@ -124,7 +135,7 @@ class Transformer(nn.Module): FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) ])) - self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options + self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options def forward( self,