only use the stable layernorm for final output norm in transformer

commit 79e2a3bc77
parent 544cdd0b29
Author: Phil Wang
Date: 2022-07-13 07:56:25 -07:00

2 changed files with 13 additions and 7 deletions

View File

@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -669,7 +675,7 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
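The default `pb_relax_alpha` above drops from `32 ** 2` to `128`. For context, below is a minimal sketch of how a CogView-style PB-relax alpha is commonly applied to attention logits; `Attention.forward` is not part of this diff, so the helper below (its name, tensor shapes, and einsum layout) is an illustrative assumption rather than the repository's actual code.

import torch

# hypothetical sketch of PB-relax scoring: queries are pre-divided by alpha so the
# raw logits stay small, the detached row-max is subtracted, then the logits are
# scaled back up by alpha before the softmax
def pb_relax_scores(q, k, scale, pb_relax_alpha = 128):
    sim = torch.einsum('b h i d, b h j d -> b h i j', q * scale / pb_relax_alpha, k)
    sim = (sim - sim.amax(dim = -1, keepdim = True).detach()) * pb_relax_alpha
    return sim.softmax(dim = -1)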
@@ -782,7 +788,7 @@ class CausalTransformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
-        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
     def forward(self, x):

View File

@@ -1 +1 @@
-__version__ = '0.23.2'
+__version__ = '0.23.3'
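For reference, a self-contained sketch of the LayerNorm variant as it stands after this commit, together with the usage pattern the commit title describes (stable rescaling enabled only for the transformer's final output norm); the demo tensors and variable names below are illustrative, not taken from the repository.

import torch
from torch import nn

# LayerNorm after this commit: the optional `stable` flag rescales by the detached
# per-row max before standardizing, which helps keep the normalization well behaved
# when activations grow large (e.g. in reduced precision)
class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5, stable = False):
        super().__init__()
        self.eps = eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + self.eps).rsqrt() * self.g

# per this commit, only the causal transformer's final output norm opts in;
# the norms inside the attention / feedforward blocks keep the default
x = torch.randn(2, 16, 512) * 1e4            # deliberately oversized activations
final_norm = LayerNorm(512, stable = True)   # as in CausalTransformer's norm_out
inner_norm = LayerNorm(512)                  # default used everywhere else
print(final_norm(x).shape, inner_norm(x).shape)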