mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 21:24:25 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 |
@@ -614,7 +614,6 @@ class Attention(nn.Module):
|
|||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False,
|
causal = False,
|
||||||
post_norm = False,
|
|
||||||
rotary_emb = None
|
rotary_emb = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -624,7 +623,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.norm = LayerNorm(dim)
|
self.norm = LayerNorm(dim)
|
||||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
@@ -635,7 +633,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, dim, bias = False),
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
LayerNorm(dim) if post_norm else nn.Identity()
|
LayerNorm(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None, attn_bias = None):
|
def forward(self, x, mask = None, attn_bias = None):
|
||||||
@@ -692,8 +690,7 @@ class Attention(nn.Module):
|
|||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
out = self.to_out(out)
|
return self.to_out(out)
|
||||||
return self.post_norm(out)
|
|
||||||
|
|
||||||
class CausalTransformer(nn.Module):
|
class CausalTransformer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
|
|||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
@@ -1181,7 +1178,11 @@ class CrossAttention(nn.Module):
|
|||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
|
LayerNorm(dim)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context, mask = None):
|
def forward(self, x, context, mask = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
|
|||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
params,
|
params,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.999),
|
||||||
filter_by_requires_grad = False
|
filter_by_requires_grad = False
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
decoder,
|
decoder,
|
||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
|
|||||||
Reference in New Issue
Block a user