From ff3474f05c404c41e201c03d2a5722b9f9675ab7 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 14 May 2022 14:23:52 -0700 Subject: [PATCH] normalize conditioning tokens outside of cross attention blocks --- README.md | 1 + dalle2_pytorch/dalle2_pytorch.py | 11 ++++++++++- setup.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5c9fc16..b1a83c0 100644 --- a/README.md +++ b/README.md @@ -1017,6 +1017,7 @@ Once built, images will be saved to the same directory the command is invoked - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] decoder needs one day worth of refactor for tech debt +- [ ] allow for unet to be able to condition non-cross attention style as well ## Citations diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9f5499f..cb77c78 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1163,6 +1163,7 @@ class CrossAttention(nn.Module): dim_head = 64, heads = 8, dropout = 0., + norm_context = False ): super().__init__() self.scale = dim_head ** -0.5 @@ -1172,7 +1173,7 @@ class CrossAttention(nn.Module): context_dim = default(context_dim, dim) self.norm = LayerNorm(dim) - self.norm_context = LayerNorm(context_dim) + self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity() self.dropout = nn.Dropout(dropout) self.null_kv = nn.Parameter(torch.randn(2, dim_head)) @@ -1378,6 +1379,9 @@ class Unet(nn.Module): Rearrange('b (n d) -> b n d', n = num_image_tokens) ) if image_embed_dim != cond_dim else nn.Identity() + self.norm_cond = nn.LayerNorm(cond_dim) + self.norm_mid_cond = nn.LayerNorm(cond_dim) + # text encoding conditioning (optional) self.text_to_cond = None @@ -1593,6 +1597,11 @@ class Unet(nn.Module): mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2) + # normalize conditioning tokens + + c = self.norm_cond(c) + mid_c = self.norm_mid_cond(mid_c) + # go through the layers of the unet, down and up hiddens = [] diff --git a/setup.py b/setup.py index e8ef76d..70e44b1 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.22', + version = '0.2.23', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',