mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 08:44:20 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
848e8a480a | ||
|
|
cc58f75474 | ||
|
|
3b2cf7b0bc |
@@ -1124,7 +1124,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||||
|
|
||||||
if self.self_cond:
|
if self.self_cond:
|
||||||
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
|
learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
|
||||||
|
|
||||||
tokens = torch.cat((
|
tokens = torch.cat((
|
||||||
text_encodings,
|
text_encodings,
|
||||||
@@ -1334,10 +1334,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
# predict noise
|
# predict noise
|
||||||
|
|
||||||
if self.predict_x_start or self.predict_v:
|
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
|
||||||
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
|
|
||||||
else:
|
|
||||||
pred_noise = pred
|
|
||||||
|
|
||||||
if time_next < 0:
|
if time_next < 0:
|
||||||
image_embed = x_start
|
image_embed = x_start
|
||||||
@@ -2975,10 +2972,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# predict noise
|
# predict noise
|
||||||
|
|
||||||
if predict_x_start or predict_v:
|
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
|
||||||
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
|
|
||||||
else:
|
|
||||||
pred_noise = pred
|
|
||||||
|
|
||||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.12.0'
|
__version__ = '1.12.3'
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -27,7 +27,7 @@ setup(
|
|||||||
'accelerate',
|
'accelerate',
|
||||||
'click',
|
'click',
|
||||||
'open-clip-torch>=2.0.0,<3.0.0',
|
'open-clip-torch>=2.0.0,<3.0.0',
|
||||||
'clip-anytorch>=2.4.0',
|
'clip-anytorch>=2.5.2',
|
||||||
'coca-pytorch>=0.0.5',
|
'coca-pytorch>=0.0.5',
|
||||||
'ema-pytorch>=0.0.7',
|
'ema-pytorch>=0.0.7',
|
||||||
'einops>=0.4',
|
'einops>=0.4',
|
||||||
|
|||||||
Reference in New Issue
Block a user