Mirror of https://github.com/lucidrains/DALLE2-pytorch.git

Compare commits

7 Commits
| SHA1 |
|---|
| f4fe6c570d |
| 645e207441 |
| 00743b3a0b |
| 01589aff6a |
| 7ecfd76cc0 |
| 6161b61c55 |
| 1ed0f9d80b |
@@ -24,6 +24,8 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu

*ongoing at 21k steps*

- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application

## Pre-Trained Models

- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of

- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship

... and many others. Thank you! 🙏
@@ -1140,7 +1143,7 @@ This library would not have gotten to this working state without the help of
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title = {MaxViT: Multi-Axis Vision Transformer},
-    author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
     year = {2022}
 }
 ```
@@ -1196,7 +1199,7 @@ This library would not have gotten to this working state without the help of
 ```

 ```bibtex
-@misc{Saharia2022,
+@misc{Saharia2022,https://stability.ai/
     title = {Imagen: unprecedented photorealism × deep level of language understanding},
     author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
     year = {2022}
@@ -1346,6 +1346,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        num_resnet_blocks = 1,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1431,6 +1432,7 @@ class Unet(nn.Module):
         # resnet block klass

         resnet_groups = cast_tuple(resnet_groups, len(in_out))
+        num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))

         assert len(resnet_groups) == len(in_out)
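The two `cast_tuple` calls above are what let `num_resnet_blocks`, like `resnet_groups`, be passed either as a single int applied to every resolution or as one value per resolution. A minimal sketch of that broadcasting, with a stand-in `cast_tuple` written from how it is used here (the repository's own helper may differ in detail):

```python
# Stand-in for the repository's cast_tuple helper, inferred from its usage here:
# a scalar is broadcast to a tuple of the requested length, a tuple passes through.
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else (val,) * length

in_out = [(128, 256), (256, 512), (512, 1024)]  # hypothetical per-resolution channel pairs

resnet_groups     = cast_tuple(8, len(in_out))          # -> (8, 8, 8)
num_resnet_blocks = cast_tuple((1, 2, 4), len(in_out))  # -> (1, 2, 4), already one value per resolution

assert len(resnet_groups) == len(in_out)
assert len(num_resnet_blocks) == len(in_out)
```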
@@ -1446,7 +1448,7 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)

-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
             is_first = ind == 0
             is_last = ind >= (num_resolutions - 1)
             layer_cond_dim = cond_dim if not is_first else None
@@ -1454,7 +1456,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
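Note that the extra per-stage blocks are wrapped in an inner `nn.ModuleList` rather than a plain Python list: only `nn.Module` containers register their children, so the new blocks' parameters stay visible to `.parameters()` and to the optimizer. A tiny self-contained check of that behavior (toy modules, unrelated to this repository):

```python
import torch.nn as nn

inner = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
stage = nn.ModuleList([nn.Conv2d(4, 4, 3, padding = 1), inner])

# Parameters of the nested blocks are registered and show up in .parameters();
# with a plain Python list in place of `inner`, they would be silently skipped.
print(sum(p.numel() for p in stage.parameters()))
```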
@@ -1464,14 +1466,14 @@ class Unet(nn.Module):
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

-        for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
+        for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None

             self.ups.append(nn.ModuleList([
                 ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 Upsample(dim_in)
             ]))
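In the upsampling path the same per-stage settings are consumed in reverse, so each decoder stage reuses the group count and block count of the encoder stage it mirrors. A small illustration of what that `zip` over reversed tuples yields, using made-up channel pairs:

```python
# Hypothetical per-stage settings for a 3-resolution Unet
in_out            = [(128, 256), (256, 512), (512, 1024)]
resnet_groups     = (8, 8, 8)
num_resnet_blocks = (1, 2, 4)

# Mirrors the loop above: skip the outermost pair, then walk the rest backwards
for (dim_in, dim_out), groups, blocks in zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks)):
    print(dim_in, dim_out, groups, blocks)
# -> 512 1024 8 4  (deepest stage first, with its 4 blocks)
# -> 256 512 8 2
```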
@@ -1628,10 +1630,13 @@ class Unet(nn.Module):

         hiddens = []

-        for block1, sparse_attn, block2, downsample in self.downs:
-            x = block1(x, c, t)
+        for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             hiddens.append(x)
             x = downsample(x)
@@ -1642,11 +1647,14 @@ class Unet(nn.Module):

         x = self.mid_block2(x, mid_c, t)

-        for block1, sparse_attn, block2, upsample in self.ups:
+        for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = block1(x, c, t)
+            x = init_block(x, c, t)
             x = sparse_attn(x)
-            x = block2(x, c, t)
+
+            for resnet_block in resnet_blocks:
+                x = resnet_block(x, c, t)
+
             x = upsample(x)

         return self.final_conv(x)
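Taken together, these Unet changes make the number of ResNet blocks configurable per resolution, the same way `resnet_groups` already was. A hedged usage sketch; the other constructor arguments follow the README examples and are not a guaranteed minimal set:

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (1, 2, 2, 4)  # one entry per resolution, or a single int applied to all
)
```

Giving the deeper, lower-resolution stages more blocks is a common pattern, since those feature maps are cheapest to process.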
@@ -21,8 +21,6 @@ def get_optimizer(
     if wd == 0:
         return Adam(params, lr = lr, betas = betas, eps = eps)

-    params = set(params)
-
     if group_wd_params:
         wd_params, no_wd_params = separate_weight_decayable_params(params)
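For context, the `group_wd_params` branch splits parameters so that weight decay is only applied where it makes sense (typically excluding biases and normalization weights) and hands the optimizer two parameter groups. A generic sketch of that pattern under the same names; it is not the repository's exact implementation:

```python
from torch.optim import AdamW

def separate_weight_decayable_params(params):
    # Common heuristic: tensors with fewer than 2 dims (biases, norm scales) skip weight decay
    wd_params, no_wd_params = [], []
    for p in params:
        (wd_params if p.ndim >= 2 else no_wd_params).append(p)
    return wd_params, no_wd_params

def get_optimizer(params, lr = 3e-4, wd = 1e-2, betas = (0.9, 0.999), eps = 1e-8):
    params = list(params)  # keep the original ordering; a set would make it nondeterministic
    wd_params, no_wd_params = separate_weight_decayable_params(params)
    param_groups = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0.}
    ]
    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
```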
setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.3',
+  version = '0.5.5',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -347,7 +347,7 @@ def train(
         # Compute evaluation metrics
         if exists(evaluate_config):
             print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
+            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
             tracker.log(evaluation, step=step, verbose=True)

         # Generate sample images
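The `.dict()` call is needed because `**` unpacking requires a mapping; if `evaluate_config` is a pydantic model, as the training configs in this script appear to be, it has to be converted to a plain dict first. A minimal illustration with a hypothetical config class (field names are made up):

```python
from pydantic import BaseModel

class EvaluateConfig(BaseModel):  # hypothetical stand-in for the script's evaluation config
    n_evaluation_samples: int = 64
    FID: bool = True

def evaluate_trainer(trainer, dataloader, device, **kwargs):
    return kwargs  # stand-in: the real function consumes these as keyword arguments

evaluate_config = EvaluateConfig()

# `**evaluate_config` would raise "argument after ** must be a mapping",
# since a BaseModel is not a mapping; .dict() converts it into one.
metrics = evaluate_trainer(None, None, "cpu", **evaluate_config.dict())
```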