allow for final l2norm clamping of the sampled image embed

This commit is contained in:
Phil Wang
2022-07-10 09:44:31 -07:00
parent 4173e88121
commit 7ea314e2f0
2 changed files with 22 additions and 6 deletions

View File

@@ -923,7 +923,8 @@ class DiffusionPrior(nn.Module):
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
training_clamp_l2norm = False, training_clamp_l2norm = False,
init_image_embed_l2norm = False, init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -963,23 +964,32 @@ class DiffusionPrior(nn.Module):
self.condition_on_text_encodings = condition_on_text_encodings self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5) self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling # whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker # device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False) self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property @property
def device(self): def device(self):
return self._dummy.device return self._dummy.device
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.): def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -1020,6 +1030,9 @@ class DiffusionPrior(nn.Module):
times = torch.full((batch,), i, device = device, dtype = torch.long) times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale) image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed return image_embed
@torch.no_grad() @torch.no_grad()
@@ -1055,7 +1068,7 @@ class DiffusionPrior(nn.Module):
x_start.clamp_(-1., 1.) x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm: if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = l2norm(x_start) * self.image_embed_scale x_start = self.l2norm_clamp_embed(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
@@ -1065,6 +1078,9 @@ class DiffusionPrior(nn.Module):
c1 * noise + \ c1 * noise + \
c2 * pred_noise c2 * pred_noise
if self.predict_x_start and self.sampling_final_clamp_l2norm:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed return image_embed
@torch.no_grad() @torch.no_grad()
@@ -1091,7 +1107,7 @@ class DiffusionPrior(nn.Module):
) )
if self.predict_x_start and self.training_clamp_l2norm: if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed target = noise if not self.predict_x_start else image_embed

View File

@@ -1 +1 @@
__version__ = '0.19.5' __version__ = '0.19.6'