From 3b2cf7b0bc152d826f74a90f5f6b922a8b9f7b21 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 11 Feb 2023 17:18:40 -0800 Subject: [PATCH] fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273 --- dalle2_pytorch/dalle2_pytorch.py | 2 +- dalle2_pytorch/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9524eac..021d5a1 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module): learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) if self.self_cond: - learned_queries = torch.cat((image_embed, self_cond), dim = -2) + learned_queries = torch.cat((self_cond, learned_queries), dim = -2) tokens = torch.cat(( text_encodings, diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 666b2f7..fe70fa2 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.12.0' +__version__ = '1.12.1'