diff --git a/README.md b/README.md
index 8235ea0..a5329e0 100644
--- a/README.md
+++ b/README.md
@@ -1002,12 +1002,12 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
 - [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
+- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
 - [ ] use an experimental tracker agnostic setup, as done here
@@ -1093,4 +1093,15 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@misc{wang2021crossformer,
+    title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
+    author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
+    year = {2021},
+    eprint = {2108.00154},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 611f0c1..af5552e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -41,6 +41,9 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def is_odd(n):
+    return (n % 2) == 1
+
 def default(val, d):
     if exists(val):
         return val
@@ -1228,6 +1231,32 @@ class LinearAttention(nn.Module):
         out = self.nonlin(out)
         return self.to_out(out)
 
+class CrossEmbedLayer(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        kernel_sizes,
+        stride = 2
+    ):
+        super().__init__()
+        assert all([*map(is_odd, kernel_sizes)])
+
+        kernel_sizes = sorted(kernel_sizes)
+        num_scales = len(kernel_sizes)
+
+        # calculate the dimension at each scale
+        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
+        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
+
+        self.convs = nn.ModuleList([])
+        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
+            self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
+
+    def forward(self, x):
+        fmaps = tuple(map(lambda conv: conv(x), self.convs))
+        return torch.cat(fmaps, dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
@@ -1252,6 +1281,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        init_cross_embed_kernel_sizes = (3, 7, 15),
         **kwargs
     ):
         super().__init__()
@@ -1270,10 +1300,9 @@ class Unet(nn.Module):
         self.channels = channels
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
-        init_dim = default(init_dim, dim // 2)
+        init_dim = default(init_dim, dim // 3 * 2)
 
-        assert (init_conv_kernel_size % 2) == 1
-        self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+        self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
diff --git a/setup.py b/setup.py
index 3a419af..86d5df6 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.6',
+  version = '0.2.7',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
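
For reference (not part of the diff above), here is a minimal sketch of what the new `CrossEmbedLayer` does at the Unet stem: it splits `dim_out` across the sorted kernel sizes (halving the channel count per extra scale, with the remainder going to the largest kernel) and concatenates the parallel conv outputs. The import path and the example dimensions are assumptions for illustration only.

```python
# illustrative sketch - assumes CrossEmbedLayer is importable from dalle2_pytorch.dalle2_pytorch
import torch
from dalle2_pytorch.dalle2_pytorch import CrossEmbedLayer

# with dim_out = 128 and kernel sizes (3, 7, 15), the channel split is [64, 32, 32]
# (halved at each extra scale, remainder assigned to the largest kernel)
layer = CrossEmbedLayer(dim_in = 3, dim_out = 128, kernel_sizes = (3, 7, 15), stride = 1)

x = torch.randn(1, 3, 64, 64)
out = layer(x)

# stride = 1 with padding = (kernel - 1) // 2 preserves spatial resolution,
# so the per-scale feature maps concatenate cleanly along the channel dimension
assert out.shape == (1, 128, 64, 64)
```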