bump up the epsilon in LayerNorm in the case where fp16 is being used, thanks to @Veldrovive for figuring out this stabilizes training

This commit is contained in:
Phil Wang
2022-07-29 12:41:02 -07:00
parent 748c7fe7af
commit 2d67d5821e
2 changed files with 11 additions and 5 deletions
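
Why a larger epsilon helps in half precision (a minimal illustration, not part of the commit): fp16 carries roughly three decimal digits of precision, so the default eps of 1e-5 can be rounded away entirely when added to a variance near 1.0, and a near-zero variance then lets (var + eps).rsqrt() produce a very large multiplier. The 1e-3 value sits above fp16's machine epsilon (~9.8e-4), so it survives the addition:

import torch

# machine epsilon of float16: spacing between adjacent values near 1.0
print(torch.finfo(torch.float16).eps)   # 0.0009765625

var = torch.tensor(1.0, dtype = torch.float16)
print(var + 1e-5)                       # tensor(1., dtype=torch.float16) -- the 1e-5 is rounded away
print(var + 1e-3)                       # tensor(1.0010, dtype=torch.float16) -- survives rounding

# with a degenerate (near-zero) variance, the epsilon alone bounds the
# normalization multiplier; the larger value keeps it an order of magnitude smaller
zero = torch.tensor(0.0, dtype = torch.float16)
print((zero + 1e-5).rsqrt())            # ~316
print((zero + 1e-3).rsqrt())            # ~31.6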


@@ -547,34 +547,40 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = -1, keepdim = True).detach()
 
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = 1, keepdim = True).detach()
 
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class Residual(nn.Module):
     def __init__(self, fn):
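
For reference, a standalone functional sketch of the patched forward pass (layer_norm is a hypothetical helper name, not part of the repo), showing the dtype-based epsilon switch in isolation:

import torch

def layer_norm(x, g, eps = 1e-5, fp16_eps = 1e-3):
    # pick the epsilon by input dtype, as in the patched LayerNorm.forward
    eps = eps if x.dtype == torch.float32 else fp16_eps
    var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
    mean = torch.mean(x, dim = -1, keepdim = True)
    return (x - mean) * (var + eps).rsqrt() * g

g = torch.ones(512, dtype = torch.float16)
x = torch.randn(2, 64, 512).half()
out = layer_norm(x, g)   # fp16 input, so eps = 1e-3 is used
# (half-precision reductions on CPU need a recent PyTorch; on GPU this is routine)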


@@ -1 +1 @@
-__version__ = '1.4.0'
+__version__ = '1.4.2'