Compare commits

..

1 Commits

2 changed files with 3 additions and 1 deletions

View File

@@ -496,6 +496,7 @@ class LayerNorm(nn.Module):
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -507,6 +508,7 @@ class ChanLayerNorm(nn.Module):
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g

View File

@@ -1 +1 @@
__version__ = '0.16.1'
__version__ = '0.16.2'