Compare commits

...

2 Commits

Author | SHA1 | Message | Date
Phil Wang | 70282de23b | add ability to turn on normformer settings, given @borisdayma reported good results and some personal anecdata | 2022-05-02 11:33:15 -07:00
Phil Wang | 83f761847e | todo | 2022-05-02 10:52:39 -07:00
3 changed files with 29 additions and 7 deletions

README.md

@@ -831,6 +831,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>

## Citations
@@ -896,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```

```bibtex
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
```

*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

dalle2_pytorch/dalle2_pytorch.py

@@ -499,7 +499,12 @@ class SwiGLU(nn.Module):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
def FeedForward(
    dim,
    mult = 4,
    dropout = 0.,
    post_activation_norm = False
):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """
    inner_dim = int(mult * dim)
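The hunk above only shows the reformatted signature and its docstring, so for orientation here is a minimal, hedged sketch of how a `post_activation_norm` flag like this one typically completes the feedforward block: the extra LayerNorm sits between the activation and the output projection, which is the NormFormer placement referenced in the docstring. `nn.LayerNorm` stands in for the repo's own `LayerNorm`, and the body below the signature is an assumption, not a quote from the file.

```python
# Hedged sketch: pre-norm SwiGLU feedforward with the NormFormer-style
# post-activation LayerNorm toggled by `post_activation_norm`.
import torch
from torch import nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
    inner_dim = int(mult * dim)
    return nn.Sequential(
        nn.LayerNorm(dim),                                                   # pre-norm on the block input
        nn.Linear(dim, inner_dim * 2, bias = False),                         # doubled width for the SwiGLU gate
        SwiGLU(),
        nn.LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),  # NormFormer: extra norm after the activation
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )

# quick shape check
ff = FeedForward(512, post_activation_norm = True)
assert ff(torch.randn(2, 16, 512)).shape == (2, 16, 512)
```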
@@ -522,7 +527,8 @@ class Attention(nn.Module):
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False
        causal = False,
        post_norm = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
@@ -537,7 +543,11 @@ class Attention(nn.Module):
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim) if post_norm else nn.Identity()
        )

    def forward(self, x, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device
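As a standalone illustration of the `to_out` change above: when `post_norm` is enabled, the attention output is normalized immediately after the output projection, inside the residual branch. The snippet below is a hedged sketch with arbitrary shapes; `nn.LayerNorm` again stands in for the repo's own `LayerNorm`.

```python
# Hedged sketch of the new output projection with the optional NormFormer
# post-attention norm appended after the final linear.
import torch
from torch import nn

dim, dim_head, heads = 512, 64, 8
inner_dim = dim_head * heads
post_norm = True

to_out = nn.Sequential(
    nn.Linear(inner_dim, dim, bias = False),
    nn.LayerNorm(dim) if post_norm else nn.Identity()
)

out = to_out(torch.randn(2, 77, inner_dim))   # (batch, seq, heads * dim_head) -> (batch, seq, dim)
assert out.shape == (2, 77, 512)
```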
@@ -602,7 +612,8 @@ class CausalTransformer(nn.Module):
        norm_out = True,
        attn_dropout = 0.,
        ff_dropout = 0.,
        final_proj = True
        final_proj = True,
        normformer = False
    ):
        super().__init__()
        self.rel_pos_bias = RelPosBias(heads = heads)
@@ -610,8 +621,8 @@ class CausalTransformer(nn.Module):
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))

        self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
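Taken together, the single `normformer` flag fans out to `post_norm` on every `Attention` and `post_activation_norm` on every `FeedForward` in the stack. A hedged usage sketch follows; the hyperparameter values are arbitrary, the import path assumes `CausalTransformer` lives in `dalle2_pytorch/dalle2_pytorch.py`, and the forward signature is inferred from the surrounding code rather than shown in this diff.

```python
# Usage sketch, not a verified snippet from the repository.
import torch
from dalle2_pytorch.dalle2_pytorch import CausalTransformer

transformer = CausalTransformer(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    normformer = True   # enables the NormFormer norms in every attention / feedforward layer
)

tokens = torch.randn(1, 256, 512)   # (batch, seq, dim)
out = transformer(tokens)           # assumed to return (batch, seq, dim)
```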

setup.py

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.92',
  version = '0.0.93',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',