mirror of https://github.com/lucidrains/DALLE2-pytorch.git
start using RMSNorm, as used in Gopher and AlphaCode, and as a way to go completely bias-less (purportedly more stable, according to PaLM)
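For reference, RMSNorm only rescales activations by their root-mean-square and applies a learned gain; unlike LayerNorm there is no mean subtraction and no bias term, which is what makes a fully bias-less model possible. A minimal sketch of the difference (illustrative only, not part of the commit; assumes a recent PyTorch):

import torch
import torch.nn as nn

dim = 8
x = torch.randn(2, 4, dim)

layer_norm = nn.LayerNorm(dim)                        # centers, rescales, learned gain + bias
rms = torch.norm(x, dim = -1, keepdim = True) * dim ** -0.5
x_rms = x / rms.clamp(min = 1e-8)                     # RMSNorm core: rescale only, no centering, no bias

print(layer_norm(x).mean(dim = -1))                   # ~0 everywhere: LayerNorm centers each token
print(x_rms.pow(2).mean(dim = -1))                    # ~1 everywhere: unit root-mean-square per token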
@@ -44,10 +44,21 @@ def freeze_model_and_make_eval_(model):
 
 # diffusion prior
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-8):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
     return nn.Sequential(
-        nn.LayerNorm(dim),
+        RMSNorm(dim),
         nn.Linear(dim, inner_dim, bias = False),
         nn.GELU(),
         nn.Dropout(dropout),
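A quick usage sketch of the module added above (arbitrary dimensions; RMSNorm and FeedForward are the definitions from this hunk, assuming torch is already imported at the top of the file; note only the first few layers of FeedForward are visible in the hunk, the full function projects back to dim):

x = torch.randn(1, 77, 512)                           # (batch, sequence, dim) dummy activations

norm = RMSNorm(512)
assert norm(x).shape == x.shape                       # normalization preserves shape

ff = FeedForward(dim = 512, mult = 4, dropout = 0.)
assert ff(x).shape == x.shape                         # the feedforward block maps back to dim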
@@ -67,7 +78,7 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
         inner_dim = dim_head * heads
 
-        self.norm = nn.LayerNorm(dim)
+        self.norm = RMSNorm(dim)
         self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
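One consequence the commit message points at: with RMSNorm replacing LayerNorm and the projections declared with bias = False, the attention block's norm and qkv projection carry no bias parameters at all. A small illustrative check (hypothetical sizes, not from the diff):

attn_norm = RMSNorm(512)
to_qkv = nn.Linear(512, 64 * 8 * 3, bias = False)      # e.g. dim_head = 64, heads = 8, fused q/k/v

assert all('bias' not in n for n, _ in attn_norm.named_parameters())   # only the gain `g`
assert all('bias' not in n for n, _ in to_qkv.named_parameters())      # only `weight`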
@@ -124,7 +135,7 @@ class Transformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
-        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
 
     def forward(
         self,
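As the inline comment notes, the final norm is optional; a rough sketch of how the norm_out flag plays out (hypothetical shapes and token selection, only to illustrate the two options):

norm_out = True
final_norm = RMSNorm(512) if norm_out else nn.Identity()   # nn.Identity() leaves tokens untouched

tokens = torch.randn(1, 77, 512)                            # transformer output tokens
pred_image_embed = final_norm(tokens[:, -1])                # e.g. read the prediction off the last token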