From 6a11b9678bc0b52c492d8e721cf02614fc7804d7 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 26 Jun 2022 21:59:55 -0700 Subject: [PATCH] bring in the skip connection scaling factor, used by imagen in their unets, cite original paper using it --- README.md | 10 ++++++++++ dalle2_pytorch/dalle2_pytorch.py | 9 ++++++++- dalle2_pytorch/version.py | 2 +- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a9c5e9e..8d89ab6 100644 --- a/README.md +++ b/README.md @@ -1189,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked } ``` +```bibtex +@article{Saharia2021PaletteID, + title = {Palette: Image-to-Image Diffusion Models}, + author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi}, + journal = {ArXiv}, + year = {2021}, + volume = {abs/2111.05826} +} +``` + *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 3ed29c0..65f53a7 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1359,6 +1359,7 @@ class Unet(nn.Module): cross_embed_downsample = False, cross_embed_downsample_kernel_sizes = (2, 4), memory_efficient = False, + scale_skip_connection = False, **kwargs ): super().__init__() @@ -1440,6 +1441,10 @@ class Unet(nn.Module): self.max_text_len = max_text_len self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) + # whether to scale skip connection, adopted in Imagen + + self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5) + # attention related params attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) @@ -1687,7 +1692,9 @@ class Unet(nn.Module): x = self.mid_block2(x, mid_c, t) for init_block, sparse_attn, resnet_blocks, upsample in self.ups: - x = torch.cat((x, hiddens.pop()), dim = 1) + skip_connect = hiddens.pop() * self.skip_connect_scale + + x = torch.cat((x, skip_connect), dim = 1) x = init_block(x, c, t) x = sparse_attn(x) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 92a60bd..c4e914a 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.12.2' +__version__ = '0.12.3'