mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-16 09:34:23 +01:00

move epsilon inside of square root for further stability in rmsnorm
improvise and use rmsnorm in convnext blocks too
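
The motivation for the first change: clamping the L2 norm keeps the forward pass finite, but the gradient of `torch.norm` at the origin is `0 / 0`, and the clamp sits downstream of that, so NaNs can still leak into the backward pass. Putting epsilon inside the square root makes the whole expression smooth, including at zero. A minimal repro of the failure mode (my own sketch, not code from this commit):

```python
import torch

dim = 4
x = torch.zeros(dim, requires_grad = True)

# old formulation: clamp the norm outside the square root
# forward is fine (0 / eps = 0), but norm's backward computes x / ||x|| = 0 / 0
norm = torch.norm(x, dim = -1, keepdim = True) * dim ** -0.5
(x / norm.clamp(min = 1e-8)).sum().backward()
print(x.grad)   # nan on the PyTorch versions I've tried - 0 * nan stays nan through clamp

x.grad = None

# new formulation: epsilon inside the (reciprocal) square root
# rsqrt(sum + eps) is smooth everywhere, including at the origin
inv_norm = torch.rsqrt((x ** 2).sum(dim = -1, keepdim = True) + 1e-5)
(x * inv_norm * dim ** 0.5).sum().backward()
print(x.grad)   # finite: every entry is sqrt(dim) / sqrt(eps)
```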
README.md (11 additions)
@@ -71,3 +71,14 @@ Todo
     year    = {2022}
 }
 ```
+
+```bibtex
+@misc{zhang2019root,
+    title   = {Root Mean Square Layer Normalization},
+    author  = {Biao Zhang and Rico Sennrich},
+    year    = {2019},
+    eprint  = {1910.07467},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
+}
+```

dalle2_pytorch/dalle2_pytorch.py

@@ -45,15 +45,16 @@ def freeze_model_and_make_eval_(model):
 # diffusion prior
 
 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-8):
+    def __init__(self, dim, eps = 1e-5):
         super().__init__()
-        self.scale = dim ** -0.5
         self.eps = eps
-        self.g = nn.Parameter(torch.ones(dim))
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
-        return x / norm.clamp(min = self.eps) * self.g
+        squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
+        inv_norm = torch.rsqrt(squared_sum + self.eps)
+        return x * inv_norm * self.gamma * self.scale
 
 def FeedForward(dim, mult = 4, dropout = 0.):
     inner_dim = int(mult * dim)
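
Note the scale flips from `dim ** -0.5` to `dim ** 0.5` because it moves from the denominator to the output: dividing by `||x|| * dim ** -0.5` and multiplying by `rsqrt(sum(x ** 2)) * dim ** 0.5` are the same normalization. A quick sanity check (my own sketch) that the new sum-based form matches the textbook mean-based RMSNorm, up to `eps` being effectively divided by `dim`:

```python
import torch

dim = 512
x = torch.randn(2, 16, dim)

# sum-based form from the diff above (gamma omitted, since it is ones at init)
sum_form  = x * torch.rsqrt((x ** 2).sum(dim = -1, keepdim = True) + 1e-5) * dim ** 0.5

# textbook mean-based RMSNorm: x / sqrt(mean(x^2) + eps'), with eps' = eps / dim
mean_form = x * torch.rsqrt((x ** 2).mean(dim = -1, keepdim = True) + 1e-5 / dim)

assert torch.allclose(sum_form, mean_form, atol = 1e-6)
```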
@@ -210,7 +211,7 @@ class ConvNextBlock(nn.Module):
 
         inner_dim = int(dim_out * mult)
         self.net = nn.Sequential(
-            LayerNorm(dim) if norm else nn.Identity(),
+            RMSNorm(dim) if norm else nn.Identity(),
             nn.Conv2d(dim, inner_dim, 3, padding = 1),
             nn.GELU(),
             nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
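
One thing this hunk doesn't show: the `RMSNorm` above reduces over the last dimension, while `ConvNextBlock` operates on `(batch, channels, height, width)` feature maps, so the normalization in the conv path has to reduce over the channel dimension instead. A channel-first variant would look something like the sketch below (the `ChanRMSNorm` name and the broadcast shapes are my assumptions, not from this diff):

```python
import torch
from torch import nn

class ChanRMSNorm(nn.Module):
    # hypothetical channel-first RMSNorm: reduces over dim = 1 and keeps gamma
    # broadcastable against (batch, channels, height, width) feature maps
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * self.gamma * self.scale

# usage: same call pattern as the RMSNorm in the diff
# ChanRMSNorm(64)(torch.randn(2, 64, 32, 32))  # -> (2, 64, 32, 32)
```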