start using RMSNorm, used in Gopher and AlphaCode, and as a way to go complete bias-less (purportedly more stable according to PaLM)

This commit is contained in:
Phil Wang
2022-04-12 10:45:03 -07:00
parent 0a60818965
commit 522f42f582

View File

@@ -44,10 +44,21 @@ def freeze_model_and_make_eval_(model):
# diffusion prior
class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim)
return nn.Sequential(
nn.LayerNorm(dim),
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Dropout(dropout),
@@ -67,7 +78,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.norm = RMSNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -124,7 +135,7 @@ class Transformer(nn.Module):
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
def forward(
self,