mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 |
@@ -614,7 +614,6 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
@@ -624,7 +623,6 @@ class Attention(nn.Module):
|
||||
|
||||
self.causal = causal
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -635,7 +633,7 @@ class Attention(nn.Module):
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
@@ -692,8 +690,7 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -1181,7 +1178,11 @@ class CrossAttention(nn.Module):
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, x, context, mask = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
filter_by_requires_grad = False
|
||||
|
||||
@@ -221,7 +221,7 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
|
||||
Reference in New Issue
Block a user