Compare commits

...

2 Commits

2 changed files with 3 additions and 3 deletions

View File

@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
tokens = torch.cat((
text_encodings,
@@ -2496,7 +2496,7 @@ class Decoder(nn.Module):
dynamic_thres_percentile = 0.95,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
):
super().__init__()

View File

@@ -1 +1 @@
__version__ = '1.11.4'
__version__ = '1.12.1'