mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a598820012 | ||
|
|
4878762627 |
@@ -1059,10 +1059,10 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
new_noise = torch.randn_like(image_embed)
|
||||
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
||||
|
||||
img = x_start * alpha_next.sqrt() + \
|
||||
c1 * new_noise + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
return image_embed
|
||||
@@ -2047,7 +2047,7 @@ class Decoder(nn.Module):
|
||||
self.noise_schedulers = nn.ModuleList([])
|
||||
|
||||
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
|
||||
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
||||
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
||||
|
||||
noise_scheduler = NoiseScheduler(
|
||||
beta_schedule = unet_beta_schedule,
|
||||
@@ -2275,9 +2275,10 @@ class Decoder(nn.Module):
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
noise = torch.randn_like(img) if time_next > 0 else 0.
|
||||
|
||||
img = x_start * alpha_next.sqrt() + \
|
||||
c1 * torch.randn_like(img) + \
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
img = self.unnormalize_img(img)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.19.2'
|
||||
__version__ = '0.19.4'
|
||||
|
||||
Reference in New Issue
Block a user