mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-20 13:04:42 +01:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70282de23b | ||
|
|
83f761847e | ||
|
|
11469dc0c6 | ||
|
|
2d25c89f35 | ||
|
|
3fe96c208a |
11
README.md
11
README.md
@@ -831,6 +831,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||||
|
- [ ] use an experiment-tracker-agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
@@ -896,4 +897,14 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{Shleifer2021NormFormerIT,
|
||||||
|
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
||||||
|
author = {Sam Shleifer and Jason Weston and Myle Ott},
|
||||||
|
journal = {ArXiv},
|
||||||
|
year = {2021},
|
||||||
|
volume = {abs/2110.09456}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||||
|
|||||||
@@ -499,7 +499,12 @@ class SwiGLU(nn.Module):
|
|||||||
x, gate = x.chunk(2, dim = -1)
|
x, gate = x.chunk(2, dim = -1)
|
||||||
return x * F.silu(gate)
|
return x * F.silu(gate)
|
||||||
|
|
||||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
def FeedForward(
|
||||||
|
dim,
|
||||||
|
mult = 4,
|
||||||
|
dropout = 0.,
|
||||||
|
post_activation_norm = False
|
||||||
|
):
|
||||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||||
|
|
||||||
inner_dim = int(mult * dim)
|
inner_dim = int(mult * dim)
|
||||||
@@ -522,7 +527,8 @@ class Attention(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False
|
causal = False,
|
||||||
|
post_norm = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
@@ -537,7 +543,11 @@ class Attention(nn.Module):
|
|||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
|
LayerNorm(dim) if post_norm else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None, attn_bias = None):
|
def forward(self, x, mask = None, attn_bias = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
@@ -599,10 +609,11 @@ class CausalTransformer(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
ff_mult = 4,
|
ff_mult = 4,
|
||||||
norm_out = False,
|
norm_out = True,
|
||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.,
|
ff_dropout = 0.,
|
||||||
final_proj = True
|
final_proj = True,
|
||||||
|
normformer = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||||
@@ -610,8 +621,8 @@ class CausalTransformer(nn.Module):
|
|||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
|
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.0.91',
|
version = '0.0.93',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ def train(image_embed_dim,
|
|||||||
clip,
|
clip,
|
||||||
dp_condition_on_text_encodings,
|
dp_condition_on_text_encodings,
|
||||||
dp_timesteps,
|
dp_timesteps,
|
||||||
|
dp_l2norm_output,
|
||||||
dp_cond_drop_prob,
|
dp_cond_drop_prob,
|
||||||
dpn_depth,
|
dpn_depth,
|
||||||
dpn_dim_head,
|
dpn_dim_head,
|
||||||
@@ -70,7 +71,8 @@ def train(image_embed_dim,
|
|||||||
dim = image_embed_dim,
|
dim = image_embed_dim,
|
||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads).to(device)
|
heads = dpn_heads,
|
||||||
|
l2norm_output = dp_l2norm_output).to(device)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
@@ -180,6 +182,7 @@ def main():
|
|||||||
# DiffusionPrior(dp) parameters
|
# DiffusionPrior(dp) parameters
|
||||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||||
|
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
||||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||||
parser.add_argument("--clip", type=str, default=None)
|
parser.add_argument("--clip", type=str, default=None)
|
||||||
@@ -223,6 +226,7 @@ def main():
|
|||||||
args.clip,
|
args.clip,
|
||||||
args.dp_condition_on_text_encodings,
|
args.dp_condition_on_text_encodings,
|
||||||
args.dp_timesteps,
|
args.dp_timesteps,
|
||||||
|
args.dp_l2norm_output,
|
||||||
args.dp_cond_drop_prob,
|
args.dp_cond_drop_prob,
|
||||||
args.dpn_depth,
|
args.dpn_depth,
|
||||||
args.dpn_dim_head,
|
args.dpn_dim_head,
|
||||||
|
|||||||
Reference in New Issue
Block a user