mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
change up epsilon in layernorm in the case of using fp16, thanks to @Veldrovive for figuring out this stabilizes training
@@ -547,34 +547,40 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = -1, keepdim = True).detach()
 
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = 1, keepdim = True).detach()
 
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class Residual(nn.Module):
     def __init__(self, fn):

@@ -1 +1 @@
-__version__ = '1.4.0'
+__version__ = '1.4.2'
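Why the larger epsilon helps, as a minimal sketch (not part of the commit itself): the output scale of this LayerNorm is `(x - mean) * (var + eps).rsqrt()`, so when the variance of a nearly-constant activation underflows toward zero, `eps` alone bounds the normalization gain. The fp32 default of `1e-5` allows a gain of roughly 316x, which half-precision activations and gradients tolerate poorly; the new `fp16_eps = 1e-3` caps it near 31.6x:

```python
import torch

# worst case for the normalization gain: the variance of a nearly-constant
# activation underflows to zero, so the output scale of
#   (x - mean) * (var + eps).rsqrt()
# is governed entirely by eps
var = torch.zeros(1)

print((var + 1e-5).rsqrt().item())  # ~316.2, the default eps kept for fp32
print((var + 1e-3).rsqrt().item())  # ~31.6, the fp16_eps added in this commit
```

The patched forward pass picks the cap by dtype (`eps = self.eps if x.dtype == torch.float32 else self.fp16_eps`), so fp32 runs keep the tighter `1e-5` while fp16 runs get the safer `1e-3`.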