mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
bring in the skip connection scaling factor used by Imagen in their unets; cite the original paper that introduced it
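For intuition on the 2 ** -0.5 factor, a minimal sketch (not part of the commit): the usual justification comes from summed residual branches, where adding two independent unit-variance signals doubles the variance, and scaling by 2 ** -0.5 restores it. In this Unet the skip is concatenated rather than summed, so the factor simply damps the skip branch relative to the main path.

```python
import torch

# stand-ins for the upsampling activation and the popped skip connection,
# both roughly unit variance
a = torch.randn(1_000_000)
b = torch.randn(1_000_000)

print((a + b).var())                # ~2.0: variance doubles when summed
print(((a + b) * 2 ** -0.5).var())  # ~1.0: the scale restores unit variance
```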
README.md                        | 10 ++++++++++
dalle2_pytorch/dalle2_pytorch.py |  9 ++++++++-
dalle2_pytorch/version.py        |  2 +-
3 files changed, 19 insertions(+), 2 deletions(-)

README.md
@@ -1189,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Saharia2021PaletteID,
+    title   = {Palette: Image-to-Image Diffusion Models},
+    author  = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2111.05826}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

dalle2_pytorch/dalle2_pytorch.py
@@ -1359,6 +1359,7 @@ class Unet(nn.Module):
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
         memory_efficient = False,
+        scale_skip_connection = False,
         **kwargs
     ):
         super().__init__()
@@ -1440,6 +1441,10 @@ class Unet(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
 
+        # whether to scale skip connection, adopted in Imagen
+
+        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
+
         # attention related params
 
         attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
@@ -1687,7 +1692,9 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c, t)
 
         for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
-            x = torch.cat((x, hiddens.pop()), dim = 1)
+            skip_connect = hiddens.pop() * self.skip_connect_scale
+
+            x = torch.cat((x, skip_connect), dim = 1)
             x = init_block(x, c, t)
             x = sparse_attn(x)
 

dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.12.2'
+__version__ = '0.12.3'
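
A usage sketch for the new flag; the Unet hyperparameters below are the illustrative ones from this repo's README, and only scale_skip_connection is specific to this commit:

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,                     # illustrative values from the README
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    scale_skip_connection = True   # new: skips are multiplied by 2 ** -0.5 in forward
)
```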