diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index ab18a19..16f8095 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -137,23 +137,27 @@ def sigmoid_beta_schedule(timesteps): # diffusion prior -class RMSNorm(nn.Module): +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +class ChanLayerNorm(nn.Module): def __init__(self, dim, eps = 1e-5): super().__init__() self.eps = eps - self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(dim)) + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) def forward(self, x): - squared_sum = (x ** 2).sum(dim = -1, keepdim = True) - inv_norm = torch.rsqrt(squared_sum + self.eps) - return x * inv_norm * self.gamma * self.scale + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) / (var + self.eps).sqrt() * self.g -class ChanRMSNorm(RMSNorm): - def forward(self, x): - squared_sum = (x ** 2).sum(dim = 1, keepdim = True) - inv_norm = torch.rsqrt(squared_sum + self.eps) - return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale class Residual(nn.Module): def __init__(self, fn): @@ -249,10 +253,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False): inner_dim = int(mult * dim) return nn.Sequential( - RMSNorm(dim), + LayerNorm(dim), nn.Linear(dim, inner_dim * 2, bias = False), SwiGLU(), - RMSNorm(inner_dim) if post_activation_norm else nn.Identity(), + LayerNorm(inner_dim) if post_activation_norm else nn.Identity(), nn.Dropout(dropout), nn.Linear(inner_dim, dim, bias = False) ) @@ -275,7 +279,8 @@ class Attention(nn.Module): inner_dim = dim_head * heads self.causal = causal - self.norm = RMSNorm(dim) + self.norm = LayerNorm(dim) + self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer self.dropout = nn.Dropout(dropout) self.null_kv = nn.Parameter(torch.randn(2, dim_head)) @@ -331,7 +336,8 @@ class Attention(nn.Module): out = einsum('b h i j, b j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) + out = self.to_out(out) + return self.post_norm(out) class CausalTransformer(nn.Module): def __init__( @@ -356,7 +362,7 @@ class CausalTransformer(nn.Module): FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) ])) - self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options + self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options def forward( self, @@ -720,7 +726,7 @@ class ConvNextBlock(nn.Module): inner_dim = int(dim_out * mult) self.net = nn.Sequential( - ChanRMSNorm(dim) if norm else nn.Identity(), + ChanLayerNorm(dim) if norm else nn.Identity(), nn.Conv2d(dim, inner_dim, 3, padding = 1), nn.GELU(), nn.Conv2d(inner_dim, dim_out, 3, padding = 1) @@ -756,8 +762,8 @@ class CrossAttention(nn.Module): context_dim = default(context_dim, dim) - self.norm = RMSNorm(dim) - self.norm_context = RMSNorm(context_dim) + self.norm = LayerNorm(dim) + self.norm_context = LayerNorm(context_dim) self.dropout = nn.Dropout(dropout) self.null_kv = nn.Parameter(torch.randn(2, dim_head)) diff --git a/setup.py b/setup.py index 71a3e49..be472ca 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.33', + version = '0.0.34', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',