mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix all clipping / clamping issues
This commit is contained in:
@@ -736,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
predict_x_start = True,
|
predict_x_start = True,
|
||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
|
sampling_clamp_l2norm = False
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -764,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
self.predict_x_start = predict_x_start
|
self.predict_x_start = predict_x_start
|
||||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||||
|
|
||||||
|
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||||
|
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||||
|
|
||||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||||
pred = self.net(x, t, **text_cond)
|
pred = self.net(x, t, **text_cond)
|
||||||
|
|
||||||
@@ -777,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
if clip_denoised and not self.predict_x_start:
|
if clip_denoised and not self.predict_x_start:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
|
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||||
|
x_recon = l2norm(x_recon)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@@ -1232,6 +1239,7 @@ class Unet(nn.Module):
|
|||||||
text_tokens = None
|
text_tokens = None
|
||||||
|
|
||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
|
text_tokens = self.text_to_cond(text_encodings)
|
||||||
text_tokens = text_tokens[:, :self.max_text_len]
|
text_tokens = text_tokens[:, :self.max_text_len]
|
||||||
|
|
||||||
text_tokens_len = text_tokens.shape[1]
|
text_tokens_len = text_tokens.shape[1]
|
||||||
@@ -1244,9 +1252,9 @@ class Unet(nn.Module):
|
|||||||
if remainder > 0:
|
if remainder > 0:
|
||||||
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||||
|
|
||||||
text_keep_mask &= text_mask
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||||
|
text_keep_mask = text_mask & text_keep_mask
|
||||||
|
|
||||||
text_tokens = self.text_to_cond(text_encodings)
|
|
||||||
text_tokens = torch.where(
|
text_tokens = torch.where(
|
||||||
text_keep_mask,
|
text_keep_mask,
|
||||||
text_tokens,
|
text_tokens,
|
||||||
@@ -1350,6 +1358,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
|
clip_denoised = True,
|
||||||
|
clip_x_start = True
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1426,6 +1436,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.image_cond_drop_prob = image_cond_drop_prob
|
self.image_cond_drop_prob = image_cond_drop_prob
|
||||||
self.text_cond_drop_prob = text_cond_drop_prob
|
self.text_cond_drop_prob = text_cond_drop_prob
|
||||||
|
|
||||||
|
# whether to clip when sampling
|
||||||
|
|
||||||
|
self.clip_denoised = clip_denoised
|
||||||
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
def get_unet(self, unet_number):
|
def get_unet(self, unet_number):
|
||||||
assert 0 < unet_number <= len(self.unets)
|
assert 0 < unet_number <= len(self.unets)
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
@@ -1459,7 +1474,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
else:
|
else:
|
||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised and not predict_x_start:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
@@ -1475,7 +1490,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
@@ -1491,7 +1506,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
text_mask = text_mask,
|
text_mask = text_mask,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
predict_x_start = predict_x_start
|
predict_x_start = predict_x_start,
|
||||||
|
clip_denoised = clip_denoised
|
||||||
)
|
)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
@@ -1542,6 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
if unet.lowres_cond:
|
if unet.lowres_cond:
|
||||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||||
|
|
||||||
|
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||||
image_size = vae.get_encoded_fmap_size(image_size)
|
image_size = vae.get_encoded_fmap_size(image_size)
|
||||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||||
|
|
||||||
@@ -1556,6 +1573,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
text_mask = text_mask,
|
text_mask = text_mask,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
predict_x_start = predict_x_start,
|
predict_x_start = predict_x_start,
|
||||||
|
clip_denoised = not is_latent_diffusion,
|
||||||
lowres_cond_img = lowres_cond_img
|
lowres_cond_img = lowres_cond_img
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user