Compare commits


3 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Phil Wang | 46a2558d53 | bug in pydantic decoder config class | 2022-06-29 07:17:35 -07:00 |
| yytdfc | 86109646e3 | fix a bug of name error (#179) | 2022-06-29 07:16:44 -07:00 |
| Phil Wang | 6a11b9678b | bring in the skip connection scaling factor, used by imagen in their unets, cite original paper using it | 2022-06-26 21:59:55 -07:00 |
4 changed files with 20 additions and 3 deletions

View File

@@ -1189,4 +1189,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Saharia2021PaletteID,
title = {Palette: Image-to-Image Diffusion Models},
author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
journal = {ArXiv},
year = {2021},
volume = {abs/2111.05826}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1359,6 +1359,7 @@ class Unet(nn.Module):
cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4),
memory_efficient = False,
scale_skip_connection = False,
**kwargs
):
super().__init__()
@@ -1440,6 +1441,10 @@ class Unet(nn.Module):
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# whether to scale skip connection, adopted in Imagen
self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
@@ -1687,7 +1692,9 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t)
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
- x = torch.cat((x, hiddens.pop()), dim = 1)
+ skip_connect = hiddens.pop() * self.skip_connect_scale
+ x = torch.cat((x, skip_connect), dim = 1)
x = init_block(x, c, t)
x = sparse_attn(x)
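For illustration, here is a minimal, self-contained sketch of the idea in the hunks above: an optional Imagen-style skip-connection scale of `2 ** -0.5` applied on the up path. This is not the repository's actual `Unet`; the `ToyUnet` module and its layers are hypothetical, and only `scale_skip_connection` / `skip_connect_scale` mirror the diff.

```python
import torch
from torch import nn

class ToyUnet(nn.Module):
    # hypothetical toy module, only to illustrate the skip_connect_scale change above
    def __init__(self, dim, scale_skip_connection = False):
        super().__init__()
        # when enabled, skip activations are multiplied by 1/sqrt(2) before concatenation,
        # so the concatenated features do not grow in magnitude relative to the main branch
        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
        self.down = nn.Conv2d(dim, dim, 3, padding = 1)
        self.mid = nn.Conv2d(dim, dim, 3, padding = 1)
        self.up = nn.Conv2d(dim * 2, dim, 3, padding = 1)
        self.pool = nn.MaxPool2d(2)
        self.unpool = nn.Upsample(scale_factor = 2)

    def forward(self, x):
        hiddens = []
        x = self.down(x)
        hiddens.append(x)                   # store the skip connection on the way down
        x = self.mid(self.pool(x))
        x = self.unpool(x)
        # scale the popped skip connection before concatenating, as in the diff above
        skip_connect = hiddens.pop() * self.skip_connect_scale
        x = torch.cat((x, skip_connect), dim = 1)
        return self.up(x)

# usage: feature maps with even spatial size so the pool/unpool shapes line up
out = ToyUnet(8, scale_skip_connection = True)(torch.randn(1, 8, 32, 32))
```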

View File

@@ -289,7 +289,7 @@ class TrainDecoderConfig(BaseModel):
# Then something else errored and we should just pass through
return values
- using_text_encodings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
+ using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
using_clip = exists(decoder_config.clip)
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
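The one-line rename above addresses the "name error" mentioned in the commit message, presumably a variable defined under one name but read back under another later in the validator. As a hedged sketch of that failure mode only, not the repository's actual `TrainDecoderConfig` schema, a pydantic v1-style root validator must reuse the exact local name it assigned; the `ToyDecoderConfig` fields below are hypothetical.

```python
from typing import Optional
from pydantic import BaseModel, root_validator

class ToyDecoderConfig(BaseModel):
    # hypothetical fields for illustration only; not the real TrainDecoderConfig schema
    cond_on_text_encodings: bool = False
    text_embeddings_url: Optional[str] = None

    @root_validator
    def check_text_conditioning(cls, values):
        # define and read back the *same* name: defining `using_text_encodings`
        # but reading `using_text_embeddings` later would raise a NameError,
        # the kind of mismatch the one-line rename above resolves
        using_text_embeddings = values.get('cond_on_text_encodings')
        if using_text_embeddings and not values.get('text_embeddings_url'):
            raise ValueError('conditioning on text requires text_embeddings_url')
        return values

# example: requesting text conditioning without an embeddings url fails validation
try:
    ToyDecoderConfig(cond_on_text_encodings = True)
except ValueError as err:
    print(err)
```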

View File

@@ -1 +1 @@
- __version__ = '0.12.2'
+ __version__ = '0.12.4'