diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 6fb02e8..1a44a74 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -527,25 +527,31 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        x = x / x.amax(dim = -1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = -1, keepdim = True).detach()
+
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps = 1e-5, stable = False):
         super().__init__()
         self.eps = eps
+        self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
-        x = x / x.amax(dim = 1, keepdim = True).detach()
+        if self.stable:
+            x = x / x.amax(dim = 1, keepdim = True).detach()
+
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -669,7 +675,7 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 32 ** 2
+        pb_relax_alpha = 128
     ):
         super().__init__()
         self.pb_relax_alpha = pb_relax_alpha
@@ -782,7 +788,7 @@ class CausalTransformer(nn.Module):
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
-        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
     def forward(self, x):
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 5a983c9..7e8f349 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.2'
+__version__ = '0.23.3'