|
|
|
|
@@ -1059,10 +1059,10 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
|
|
|
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
|
|
|
|
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
|
|
|
|
|
new_noise = torch.randn_like(image_embed)
|
|
|
|
|
|
|
|
|
|
img = x_start * alpha_next.sqrt() + \
|
|
|
|
|
c1 * noise + \
|
|
|
|
|
c1 * new_noise + \
|
|
|
|
|
c2 * pred_noise
|
|
|
|
|
|
|
|
|
|
return image_embed
|
|
|
|
|
@@ -1537,12 +1537,10 @@ class Unet(nn.Module):
|
|
|
|
|
# text encoding conditioning (optional)
|
|
|
|
|
|
|
|
|
|
self.text_to_cond = None
|
|
|
|
|
self.text_embed_dim = None
|
|
|
|
|
|
|
|
|
|
if cond_on_text_encodings:
|
|
|
|
|
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
|
|
|
|
|
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
|
|
|
|
|
self.text_embed_dim = text_embed_dim
|
|
|
|
|
|
|
|
|
|
# finer control over whether to condition on image embeddings and text encodings
|
|
|
|
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
|
|
|
|
@@ -1771,8 +1769,6 @@ class Unet(nn.Module):
|
|
|
|
|
text_tokens = None
|
|
|
|
|
|
|
|
|
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
|
|
|
|
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
|
|
|
|
|
|
|
|
|
|
text_tokens = self.text_to_cond(text_encodings)
|
|
|
|
|
text_tokens = text_tokens[:, :self.max_text_len]
|
|
|
|
|
|
|
|
|
|
@@ -2047,7 +2043,7 @@ class Decoder(nn.Module):
|
|
|
|
|
self.noise_schedulers = nn.ModuleList([])
|
|
|
|
|
|
|
|
|
|
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
|
|
|
|
|
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
|
|
|
|
assert sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
|
|
|
|
|
|
|
|
|
|
noise_scheduler = NoiseScheduler(
|
|
|
|
|
beta_schedule = unet_beta_schedule,
|
|
|
|
|
@@ -2275,10 +2271,9 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
|
|
|
|
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
|
|
|
|
noise = torch.randn_like(img) if time_next > 0 else 0.
|
|
|
|
|
|
|
|
|
|
img = x_start * alpha_next.sqrt() + \
|
|
|
|
|
c1 * noise + \
|
|
|
|
|
c1 * torch.randn_like(img) + \
|
|
|
|
|
c2 * pred_noise
|
|
|
|
|
|
|
|
|
|
img = self.unnormalize_img(img)
|
|
|
|
|
|