mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273
This commit is contained in:
@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
if self.self_cond:
|
||||
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
|
||||
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
|
||||
|
||||
tokens = torch.cat((
|
||||
text_encodings,
|
||||
|
||||
Reference in New Issue
Block a user