mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 10:14:19 +01:00
start using RMSNorm, used in Gopher and AlphaCode, and as a way to go complete bias-less (purportedly more stable according to PaLM)
This commit is contained in:
@@ -44,10 +44,21 @@ def freeze_model_and_make_eval_(model):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
|
||||
return x / norm.clamp(min = self.eps) * self.g
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
@@ -67,7 +78,7 @@ class Attention(nn.Module):
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.norm = RMSNorm(dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
@@ -124,7 +135,7 @@ class Transformer(nn.Module):
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user