From b9a908ff7522d52295dbdfd3487c9122e6cd4b98 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 5 Jul 2022 14:27:04 -0700
Subject: [PATCH] bring in two tricks from the cogview paper for reducing the
 chances of overflow, for attention and layernorm

---
 dalle2_pytorch/dalle2_pytorch.py | 21 +++++++++++++++++----
 dalle2_pytorch/version.py        |  2 +-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index e0e14c3..6a45281 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -496,6 +496,7 @@ class LayerNorm(nn.Module):
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
+        x = x / x.amax(dim = -1, keepdim = True).detach()
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -507,6 +508,7 @@ class ChanLayerNorm(nn.Module):
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
+        x = x / x.amax(dim = 1, keepdim = True).detach()
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + self.eps).rsqrt() * self.g
@@ -629,10 +631,13 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        rotary_emb = None
+        rotary_emb = None,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -696,6 +701,9 @@ class Attention(nn.Module):
 
         # attention
 
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
 
@@ -1210,10 +1218,12 @@ class CrossAttention(nn.Module):
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        norm_context = False
+        norm_context = False,
+        pb_relax_alpha = 32 ** 2
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5
+        self.pb_relax_alpha = pb_relax_alpha
+        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -1259,6 +1269,9 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        sim = sim * self.pb_relax_alpha
+
         attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 8911e95..e935064 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.16.0'
+__version__ = '0.16.2'
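
For reference, below is a minimal, standalone sketch of the two CogView stabilization tricks this patch wires into dalle2_pytorch: dividing activations by their detached row-wise max before layer normalization, and the "PB-relax" attention logits (fold 1/alpha into the usual 1/sqrt(d) scale, subtract the detached row max, then scale back up by alpha right before the softmax). The names StableLayerNorm and stable_attention are illustrative placeholders rather than identifiers from the library; only the arithmetic mirrors the patch.

# Minimal sketch of the two CogView tricks applied in this patch.
# StableLayerNorm and stable_attention are illustrative names, not
# classes/functions from dalle2_pytorch.

import torch
from torch import nn, einsum

class StableLayerNorm(nn.Module):
    # LayerNorm trick: divide by the (detached) per-row max before computing
    # mean/variance, so large activations cannot overflow the variance in fp16.
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x / x.amax(dim = -1, keepdim = True).detach()
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + self.eps).rsqrt() * self.g

def stable_attention(q, k, v, alpha = 32 ** 2):
    # Attention trick: scale logits down by 1/alpha so the q.k matmul stays
    # small, subtract the detached row max, then multiply by alpha just before
    # the softmax (computed in float32).
    scale = q.shape[-1] ** -0.5 * (alpha ** -1)
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    sim = sim * alpha
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    return einsum('b h i j, b h j d -> b h i d', attn.type(q.dtype), v)

# usage
q = k = v = torch.randn(1, 8, 16, 64)
out = stable_attention(q, k, v)            # (1, 8, 16, 64)
normed = StableLayerNorm(64)(torch.randn(1, 16, 64))

Both tricks are forward-equivalent to the unmodified ops: softmax is invariant to a per-row additive shift (and the 1/alpha and alpha factors cancel up to that shift), and layer norm is invariant to a positive per-row rescaling, so the only practical effect is keeping intermediate values small enough to avoid fp16 overflow.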