diff --git a/README.md b/README.md
index d89ad65..07ad365 100644
--- a/README.md
+++ b/README.md
@@ -1003,6 +1003,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
 - [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
+- [x] cross embed layers for downsampling, as an option
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
@@ -1015,7 +1016,6 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
-- [ ] cross embed layers for downsampling, as an option
 - [ ] decoder needs one day worth of refactor for tech debt
 
 ## Citations
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index af5552e..e27d304 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -41,9 +41,6 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
-def is_odd(n):
-    return (n % 2) == 1
-
 def default(val, d):
     if exists(val):
         return val
@@ -1235,12 +1232,13 @@ class CrossEmbedLayer(nn.Module):
     def __init__(
         self,
         dim_in,
-        dim_out,
         kernel_sizes,
+        dim_out = None,
         stride = 2
     ):
         super().__init__()
-        assert all([*map(is_odd, kernel_sizes)])
+        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
+        dim_out = default(dim_out, dim_in)
 
         kernel_sizes = sorted(kernel_sizes)
         num_scales = len(kernel_sizes)
@@ -1282,6 +1280,8 @@ class Unet(nn.Module):
         init_conv_kernel_size = 7,
         resnet_groups = 8,
         init_cross_embed_kernel_sizes = (3, 7, 15),
+        cross_embed_downsample = False,
+        cross_embed_downsample_kernel_sizes = (2, 4),
         **kwargs
     ):
         super().__init__()
@@ -1302,7 +1302,7 @@ class Unet(nn.Module):
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
         init_dim = default(init_dim, dim // 3 * 2)
 
-        self.init_conv = CrossEmbedLayer(init_channels, init_dim, init_cross_embed_kernel_sizes, stride = 1)
+        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
 
         dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
@@ -1362,6 +1362,12 @@
 
         assert len(resnet_groups) == len(in_out)
 
+        # downsample klass
+
+        downsample_klass = Downsample
+        if cross_embed_downsample:
+            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
+
         # layers
 
         self.downs = nn.ModuleList([])
@@ -1377,7 +1383,7 @@
                 ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
                 Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                Downsample(dim_out) if not is_last else nn.Identity()
+                downsample_klass(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
diff --git a/setup.py b/setup.py
index 86d5df6..8eea324 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.7',
+  version = '0.2.8',
   license='MIT',
  description = 'DALL-E 2',
   author = 'Phil Wang',
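Note on the two functional changes above: `CrossEmbedLayer` now defaults `dim_out` to `dim_in` and accepts any kernel size whose parity matches the stride (the old `is_odd` check only permitted odd kernels, which pairs correctly with the stride-1 init conv but not with a stride-2 downsample), and `Unet` can optionally swap its plain strided `Downsample` for a cross-scale embedding via `cross_embed_downsample = True`. The sketch below illustrates why the parity assertion takes this form; the per-scale conv construction (the channel split and `padding = (kernel - stride) // 2`) is not shown in this diff and is assumed from the crossformer reference linked in the README.

```python
import torch
from torch import nn

# Minimal, self-contained sketch of a cross-scale embedding layer.
# Only the signature and the parity assertion are confirmed by the diff above;
# the conv construction is assumed from the crossformer reference.

def default(val, d):
    # inlined stand-in for the repo's `default` helper
    return val if val is not None else d

class CrossEmbedLayer(nn.Module):
    def __init__(self, dim_in, kernel_sizes, dim_out = None, stride = 2):
        super().__init__()
        # kernel and stride parities must match so that padding = (kernel - stride) // 2
        # is exact and every scale produces the same spatial dimensions
        assert all([(k % 2) == (stride % 2) for k in kernel_sizes])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # split the output channels across scales, halving per scale,
        # with the remainder assigned to the largest kernel
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([
            nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        # every conv sees the same input; outputs agree spatially and concat on channels
        return torch.cat([conv(x) for conv in self.convs], dim = 1)

# even kernels pair with stride 2 (the new downsample option's defaults)
down = CrossEmbedLayer(64, kernel_sizes = (2, 4), stride = 2)
assert down(torch.randn(1, 64, 32, 32)).shape == (1, 64, 16, 16)

# odd kernels pair with stride 1 (the init conv's configuration)
init = CrossEmbedLayer(3, kernel_sizes = (3, 7, 15), dim_out = 128, stride = 1)
assert init(torch.randn(1, 3, 32, 32)).shape == (1, 128, 32, 32)
```

This parity rule is what motivates generalizing the old `is_odd` assertion: the stride-1 init conv keeps its odd kernels `(3, 7, 15)`, while the new stride-2 downsample option defaults to even kernels `(2, 4)`.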