mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 23:34:20 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bdf85a5e9 |
@@ -496,7 +496,6 @@ class LayerNorm(nn.Module):
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = -1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = -1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
@@ -508,7 +507,6 @@ class ChanLayerNorm(nn.Module):
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = x / x.amax(dim = 1, keepdim = True).detach()
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) * (var + self.eps).rsqrt() * self.g
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.16.2'
|
||||
__version__ = '0.16.1'
|
||||
|
||||
Reference in New Issue
Block a user