mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix all clipping / clamping issues
This commit is contained in:
@@ -736,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -764,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
@@ -777,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_recon = l2norm(x_recon)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@@ -1232,6 +1239,7 @@ class Unet(nn.Module):
|
||||
text_tokens = None
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = text_tokens[:, :self.max_text_len]
|
||||
|
||||
text_tokens_len = text_tokens.shape[1]
|
||||
@@ -1244,9 +1252,9 @@ class Unet(nn.Module):
|
||||
if remainder > 0:
|
||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||
|
||||
text_keep_mask &= text_mask
|
||||
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||
text_keep_mask = text_mask & text_keep_mask
|
||||
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
text_keep_mask,
|
||||
text_tokens,
|
||||
@@ -1350,6 +1358,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1426,6 +1436,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.image_cond_drop_prob = image_cond_drop_prob
|
||||
self.text_cond_drop_prob = text_cond_drop_prob
|
||||
|
||||
# whether to clip when sampling
|
||||
|
||||
self.clip_denoised = clip_denoised
|
||||
self.clip_x_start = clip_x_start
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -1459,7 +1474,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not predict_x_start:
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
@@ -1475,7 +1490,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -1491,7 +1506,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
predict_x_start = predict_x_start,
|
||||
clip_denoised = clip_denoised
|
||||
)
|
||||
|
||||
return img
|
||||
@@ -1542,6 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
@@ -1556,6 +1573,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
text_mask = text_mask,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
clip_denoised = not is_latent_diffusion,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user