mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
allow for final l2norm clamping of the sampled image embed
This commit is contained in:
@@ -923,7 +923,8 @@ class DiffusionPrior(nn.Module):
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
|
||||
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
|
||||
training_clamp_l2norm = False,
|
||||
init_image_embed_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
@@ -963,23 +964,32 @@ class DiffusionPrior(nn.Module):
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
|
||||
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
|
||||
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
|
||||
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
# device tracker
|
||||
|
||||
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._dummy.device
|
||||
|
||||
def l2norm_clamp_embed(self, image_embed):
|
||||
return l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
@@ -1020,6 +1030,9 @@ class DiffusionPrior(nn.Module):
|
||||
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
|
||||
return image_embed
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1055,7 +1068,7 @@ class DiffusionPrior(nn.Module):
|
||||
x_start.clamp_(-1., 1.)
|
||||
|
||||
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||
x_start = l2norm(x_start) * self.image_embed_scale
|
||||
x_start = self.l2norm_clamp_embed(x_start)
|
||||
|
||||
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
||||
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
|
||||
@@ -1065,6 +1078,9 @@ class DiffusionPrior(nn.Module):
|
||||
c1 * noise + \
|
||||
c2 * pred_noise
|
||||
|
||||
if self.predict_x_start and self.sampling_final_clamp_l2norm:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
|
||||
return image_embed
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1091,7 +1107,7 @@ class DiffusionPrior(nn.Module):
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
pred = self.l2norm_clamp_embed(pred)
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '0.19.5'
|
||||
__version__ = '0.19.6'
|
||||
|
||||
Reference in New Issue
Block a user