diff --git a/README.md b/README.md index 837860c..e9bed10 100644 --- a/README.md +++ b/README.md @@ -71,3 +71,14 @@ Todo year = {2022} } ``` + +```bibtex +@misc{zhang2019root, + title = {Root Mean Square Layer Normalization}, + author = {Biao Zhang and Rico Sennrich}, + year = {2019}, + eprint = {1910.07467}, + archivePrefix = {arXiv}, + primaryClass = {cs.LG} +} +``` diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 42900d8..db145aa 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -45,15 +45,16 @@ def freeze_model_and_make_eval_(model): # diffusion prior class RMSNorm(nn.Module): - def __init__(self, dim, eps = 1e-8): + def __init__(self, dim, eps = 1e-5): super().__init__() - self.scale = dim ** -0.5 self.eps = eps - self.g = nn.Parameter(torch.ones(dim)) + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) def forward(self, x): - norm = torch.norm(x, dim = -1, keepdim = True) * self.scale - return x / norm.clamp(min = self.eps) * self.g + squared_sum = (x ** 2).sum(dim = -1, keepdim = True) + inv_norm = torch.rsqrt(squared_sum + self.eps) + return x * inv_norm * self.gamma * self.scale def FeedForward(dim, mult = 4, dropout = 0.): inner_dim = int(mult * dim) @@ -210,7 +211,7 @@ class ConvNextBlock(nn.Module): inner_dim = int(dim_out * mult) self.net = nn.Sequential( - LayerNorm(dim) if norm else nn.Identity(), + RMSNorm(dim) if norm else nn.Identity(), nn.Conv2d(dim, inner_dim, 3, padding = 1), nn.GELU(), nn.Conv2d(inner_dim, dim_out, 3, padding = 1)