Compare commits

..

3 Commits

3 changed files with 5 additions and 11 deletions

View File

@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
         if self.self_cond:
-            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+            learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
         tokens = torch.cat((
             text_encodings,
@@ -1334,10 +1334,7 @@ class DiffusionPrior(nn.Module):
         # predict noise
-        if self.predict_x_start or self.predict_v:
-            pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
-        else:
-            pred_noise = pred
+        pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
         if time_next < 0:
             image_embed = x_start
@@ -2975,10 +2972,7 @@ class Decoder(nn.Module):
         # predict noise
-        if predict_x_start or predict_v:
-            pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
-        else:
-            pred_noise = pred
+        pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
         c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
         c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()

View File

@@ -1 +1 @@
-__version__ = '1.12.0'
+__version__ = '1.12.3'

View File

@@ -27,7 +27,7 @@ setup(
         'accelerate',
         'click',
         'open-clip-torch>=2.0.0,<3.0.0',
-        'clip-anytorch>=2.4.0',
+        'clip-anytorch>=2.5.2',
         'coca-pytorch>=0.0.5',
         'ema-pytorch>=0.0.7',
         'einops>=0.4',