move epsilon inside the square root for further stability in rmsnorm

improvise and use rmsnorm in convnext blocks too
This commit is contained in:
Phil Wang
2022-04-12 11:18:36 -07:00
parent cf22affcbb
commit 83aabd42ca
2 changed files with 18 additions and 6 deletions

View File

@@ -71,3 +71,14 @@ Todo
year = {2022}
}
```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```

View File

@@ -45,15 +45,16 @@ def freeze_model_and_make_eval_(model):
# diffusion prior
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).

    Normalizes the last dimension of the input by its RMS value, with a
    learned per-channel gain. Unlike LayerNorm, no mean is subtracted.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # multiplying by sqrt(dim) after dividing by the L2 norm is
        # equivalent to dividing by sqrt(mean(x ** 2)), i.e. the RMS
        self.scale = dim ** 0.5
        # learned per-channel gain, initialized to identity
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps sits inside the square root (added to the sum of squares)
        # for better numerical stability than clamping the norm
        squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * self.gamma * self.scale
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim)
@@ -210,7 +211,7 @@ class ConvNextBlock(nn.Module):
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
LayerNorm(dim) if norm else nn.Identity(),
RMSNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)