Compare commits

...

2 Commits

3 changed files with 3 additions and 3 deletions

View File

@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
tokens = torch.cat((
text_encodings,

View File

@@ -1 +1 @@
__version__ = '1.12.0'
__version__ = '1.12.2'

View File

@@ -27,7 +27,7 @@ setup(
'accelerate',
'click',
'open-clip-torch>=2.0.0,<3.0.0',
'clip-anytorch>=2.4.0',
'clip-anytorch>=2.5.2',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',