Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-14 16:54:46 +01:00

Compare commits

3 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 908088cfea |  |
|  | 8dc8a3de0d |  |
|  | 35f89556ba |  |

README.md (14 lines changed)
@@ -1002,12 +1002,13 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
 - [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
+- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
+- [x] cross embed layers for downsampling, as an option
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
@@ -1093,4 +1094,15 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@misc{wang2021crossformer,
+    title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
+    author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
+    year = {2021},
+    eprint = {2108.00154},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -1228,6 +1228,33 @@ class LinearAttention(nn.Module):
         out = self.nonlin(out)
         return self.to_out(out)
 
+class CrossEmbedLayer(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        kernel_sizes,
+        dim_out = None,
+        stride = 2
+    ):
+        super().__init__()
+        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
+        dim_out = default(dim_out, dim_in)
+
+        kernel_sizes = sorted(kernel_sizes)
+        num_scales = len(kernel_sizes)
+
+        # calculate the dimension at each scale
+        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
+        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
+
+        self.convs = nn.ModuleList([])
+        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
+            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
+
+    def forward(self, x):
+        fmaps = tuple(map(lambda conv: conv(x), self.convs))
+        return torch.cat(fmaps, dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -1252,6 +1279,9 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        init_cross_embed_kernel_sizes = (3, 7, 15),
+        cross_embed_downsample = False,
+        cross_embed_downsample_kernel_sizes = (2, 4),
         **kwargs
     ):
         super().__init__()
@@ -1270,10 +1300,9 @@ class Unet(nn.Module):
         self.channels = channels
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
-        init_dim = default(init_dim, dim // 2)
+        init_dim = default(init_dim, dim // 3 * 2)
 
-        assert (init_conv_kernel_size % 2) == 1
-        self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -1333,6 +1362,12 @@ class Unet(nn.Module):
 
         assert len(resnet_groups) == len(in_out)
 
+        # downsample klass
+
+        downsample_klass = Downsample
+        if cross_embed_downsample:
+            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
+
         # layers
 
         self.downs = nn.ModuleList([])
@@ -1348,7 +1383,7 @@ class Unet(nn.Module):
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
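
For reference, a minimal usage sketch of the `CrossEmbedLayer` introduced above. The sketch is not part of the commit: it assumes `torch` is installed and that the layer can be imported from the patched module (the exact import path is an assumption). The printed shapes follow from the channel split computed in `__init__` and the `(kernel - stride) // 2` padding.

```python
# Usage sketch only (not part of this diff); the import path below is assumed
# and may differ depending on how the module is packaged.
import torch
from dalle2_pytorch.dalle2_pytorch import CrossEmbedLayer

# stride-1 variant, used as the Unet's initial convolution in this change;
# odd kernels with padding (kernel - 1) // 2 keep the spatial size unchanged
init_embed = CrossEmbedLayer(3, kernel_sizes = (3, 7, 15), dim_out = 128, stride = 1)
x = torch.randn(1, 3, 64, 64)
print(init_embed(x).shape)  # torch.Size([1, 128, 64, 64]) - channels split 64 + 32 + 32 across the three kernel scales

# stride-2 variant, the optional downsample selected by cross_embed_downsample = True
downsample = CrossEmbedLayer(128, kernel_sizes = (2, 4), dim_out = 256, stride = 2)
print(downsample(torch.randn(1, 128, 64, 64)).shape)  # torch.Size([1, 256, 32, 32])
```

Inside the `Unet`, the stride-1 form replaces the previous `nn.Conv2d` initial convolution, and the stride-2 form is substituted for `Downsample` when `cross_embed_downsample = True`.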