Compare commits

...

2 Commits

Author     SHA1        Message                                                                            Date
Phil Wang  591d37e266  lower default initial learning rate to what Jonathan Ho had in his original repo   2022-05-14 13:22:43 -07:00
Phil Wang  d1f02e8f49  always use sandwich norm for attention layer                                       2022-05-14 12:13:41 -07:00
4 changed files with 6 additions and 9 deletions

View File

@@ -614,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -624,7 +623,6 @@ class Attention(nn.Module):
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -635,7 +633,7 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )

     def forward(self, x, mask = None, attn_bias = None):
@@ -692,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)

 class CausalTransformer(nn.Module):
     def __init__(
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
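
The net effect of d1f02e8f49 is that the attention block now always applies a LayerNorm on the input and another on the projected output, i.e. the sandwich-norm pattern from Normformer / CogView, instead of gating the output norm behind the post_norm / normformer flag. Below is a minimal, hypothetical sketch of that pattern only; SandwichNormBlock and sublayer are illustrative names, not the repo's actual classes, and the residual placement is for illustration.

import torch
from torch import nn

class SandwichNormBlock(nn.Module):
    # Sandwich norm: LayerNorm on the input (pre-norm) and another LayerNorm
    # on the sublayer's output (post-norm), both applied unconditionally.
    def __init__(self, dim, sublayer):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.sublayer = sublayer           # stands in for attention / feedforward
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x):
        # residual around post_norm(sublayer(pre_norm(x)))
        return x + self.post_norm(self.sublayer(self.pre_norm(x)))

block = SandwichNormBlock(dim = 512, sublayer = nn.Linear(512, 512))
out = block(torch.randn(2, 16, 512))  # -> shape (2, 16, 512)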

View File

@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
 def get_optimizer(
     params,
-    lr = 3e-4,
+    lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
     filter_by_requires_grad = False
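
Downstream code that relied on the old 3e-4 default now needs to pass it explicitly. A usage sketch, assuming get_optimizer is importable from dalle2_pytorch.optimizer (the import path and the placeholder model are assumptions; only the signature above comes from the diff):

import torch
from dalle2_pytorch.optimizer import get_optimizer  # import path assumed

model = torch.nn.Linear(512, 512)  # placeholder model

opt = get_optimizer(model.parameters())             # now defaults to lr = 2e-5
opt = get_optimizer(model.parameters(), lr = 3e-4)  # previous default, set explicitly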

View File

@@ -221,7 +221,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
-        lr = 3e-4,
+        lr = 2e-5,
         wd = 1e-2,
         max_grad_norm = None,
         amp = False,
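
The same default applies at the trainer level. A brief sketch, assuming DecoderTrainer is importable from the package top level and that decoder is an already-constructed Decoder (both assumptions; construction omitted):

from dalle2_pytorch import DecoderTrainer  # top-level import assumed

# decoder = ...  # an already-built dalle2_pytorch Decoder (omitted here)
trainer = DecoderTrainer(decoder)             # lr now defaults to 2e-5
trainer = DecoderTrainer(decoder, lr = 3e-4)  # restore the previous default explicitly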

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.18',
+  version = '0.2.20',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',