fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273

This commit is contained in:
Phil Wang
2023-02-11 17:18:40 -08:00
parent 984d62a373
commit 3b2cf7b0bc
2 changed files with 2 additions and 2 deletions

View File

@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond: if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2) learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
tokens = torch.cat(( tokens = torch.cat((
text_encodings, text_encodings,

View File

@@ -1 +1 @@
__version__ = '1.12.0' __version__ = '1.12.1'