diff --git a/README.md b/README.md
index 8235ea0..a5329e0 100644
--- a/README.md
+++ b/README.md
@@ -1002,12 +1002,12 @@ Once built, images will be saved to the same directory the command is invoked
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
+- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done here
@@ -1093,4 +1093,15 @@ Once built, images will be saved to the same directory the command is invoked
}
```
+```bibtex
+@misc{wang2021crossformer,
+ title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
+ author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
+ year = {2021},
+ eprint = {2108.00154},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 611f0c1..af5552e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -41,6 +41,9 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
+def is_odd(n):
+ return (n % 2) == 1
+
def default(val, d):
if exists(val):
return val
@@ -1228,6 +1231,32 @@ class LinearAttention(nn.Module):
out = self.nonlin(out)
return self.to_out(out)
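+
+# cross-scale embedding layer from the CrossFormer paper https://arxiv.org/abs/2108.00154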
+class CrossEmbedLayer(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_out,
+ kernel_sizes,
+ stride = 2
+ ):
+ super().__init__()
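+        # kernel sizes must all be odd so that the (kernel - stride) // 2 padding below keeps every scale at the same output resolution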
+ assert all([*map(is_odd, kernel_sizes)])
+
+ kernel_sizes = sorted(kernel_sizes)
+ num_scales = len(kernel_sizes)
+
+        # calculate the dimension at each scale - the splits sum to dim_out, with smaller kernels receiving the larger share of channels
+ dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
+ dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
+
+ self.convs = nn.ModuleList([])
+ for kernel, dim_scale in zip(kernel_sizes, dim_scales):
+ self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
+
+ def forward(self, x):
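+        # convolve the input at every scale and concatenate the resulting feature maps along the channel dimension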
+ fmaps = tuple(map(lambda conv: conv(x), self.convs))
+ return torch.cat(fmaps, dim = 1)
+
class Unet(nn.Module):
def __init__(
self,
@@ -1252,6 +1281,7 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
+ init_cross_embed_kernel_sizes = (3, 7, 15),
**kwargs
):
super().__init__()
@@ -1270,10 +1300,9 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
- init_dim = default(init_dim, dim // 2)
+ init_dim = default(init_dim, dim // 3 * 2)
- assert (init_conv_kernel_size % 2) == 1
- self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+ self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
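
The new `CrossEmbedLayer` stem replaces the single initial convolution with several parallel convolutions of different (odd) kernel sizes whose outputs are concatenated along the channel dimension. A minimal sketch of how the layer could be exercised on its own; the input shape and `dim_out = 128` are illustrative, while the kernel sizes `(3, 7, 15)` and `stride = 1` match the defaults used for the unet's initial convolution:

```python
import torch
from dalle2_pytorch.dalle2_pytorch import CrossEmbedLayer

# stride of 1 preserves spatial resolution, as in the unet's initial convolution
layer = CrossEmbedLayer(dim_in = 3, dim_out = 128, kernel_sizes = (3, 7, 15), stride = 1)

x = torch.randn(1, 3, 64, 64)
out = layer(x)

# the three scales contribute 64, 32 and 32 channels, concatenated back to dim_out
print(out.shape) # torch.Size([1, 128, 64, 64])
```
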
diff --git a/setup.py b/setup.py
index 3a419af..86d5df6 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.2.6',
+ version = '0.2.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',