mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
allow for final l2norm clamping of the sampled image embed
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -922,11 +922,12 @@ class DiffusionPrior(nn.Module):
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False,
+        sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+        sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
         training_clamp_l2norm = False,
         init_image_embed_l2norm = False,
         image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
         super().__init__()
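Note on the two flags above: sampling_clamp_l2norm renorms the predicted x0 at every denoising step, while sampling_final_clamp_l2norm (added in this commit) renorms the sampled embedding once more at the very end. A minimal construction sketch, following the repository README's DiffusionPrior-without-CLIP usage; the network hyperparameters and the image_embed_dim / timesteps / cond_drop_prob values are illustrative assumptions, not part of this diff:

    import torch
    from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

    prior_network = DiffusionPriorNetwork(
        dim = 512,       # illustrative hyperparameters
        depth = 6,
        dim_head = 64,
        heads = 8
    )

    diffusion_prior = DiffusionPrior(
        net = prior_network,
        image_embed_dim = 512,
        timesteps = 100,
        cond_drop_prob = 0.2,
        condition_on_text_encodings = False,
        sampling_clamp_l2norm = True,       # renorm the image embed at each denoising step
        sampling_final_clamp_l2norm = True  # new in this commit: renorm the final output as well
    )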
@@ -963,23 +964,32 @@ class DiffusionPrior(nn.Module):
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

         self.predict_x_start = predict_x_start

         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

         self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)

         # whether to force an l2norm, similar to clipping denoised, when sampling

         self.sampling_clamp_l2norm = sampling_clamp_l2norm
+        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm

         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm

         # device tracker

         self.register_buffer('_dummy', torch.tensor([True]), persistent = False)

     @property
     def device(self):
         return self._dummy.device

+    def l2norm_clamp_embed(self, image_embed):
+        return l2norm(image_embed) * self.image_embed_scale
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
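The new l2norm_clamp_embed helper centralizes a renormalization that previously appeared inline in the sampling and training paths (see the hunks below). A self-contained sketch of what it computes, assuming l2norm is the usual F.normalize wrapper defined earlier in this file, and using the sqrt(dim) default for image_embed_scale from the @crowsonkb suggestion above:

    import torch
    import torch.nn.functional as F

    def l2norm_clamp_embed(image_embed, image_embed_scale):
        # project the embedding back onto a sphere of radius image_embed_scale,
        # the embedding-space analogue of clamping pixels to [-1, 1] in a DDPM
        return F.normalize(image_embed, dim = -1) * image_embed_scale

    dim = 512
    embed = torch.randn(4, dim)
    clamped = l2norm_clamp_embed(embed, dim ** 0.5)

    # every row now has norm sqrt(512) ~= 22.63, the expected norm of a
    # unit-variance isotropic gaussian in 512 dimensions
    print(clamped.norm(dim = -1))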
@@ -1020,6 +1030,9 @@ class DiffusionPrior(nn.Module):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

+        if self.sampling_final_clamp_l2norm and self.predict_x_start:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed

     @torch.no_grad()
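In the ancestral sampling loop above, the per-step clamp (sampling_clamp_l2norm) fires inside p_sample, while the new flag fires exactly once after the loop, mirroring how image DDPMs keep the final output clamped to [-1, 1]. A condensed view of the control flow, as a sketch rather than the literal method body:

    # condensed sketch of the sampling path after this commit
    for t in reversed(range(num_timesteps)):
        image_embed = p_sample(image_embed, t)          # per-step l2norm clamp happens in here
    if sampling_final_clamp_l2norm and predict_x_start:
        image_embed = l2norm_clamp_embed(image_embed)   # one final projection back onto the sphere

The predict_x_start guard presumably matches the other clamping options, all of which are defined for the x0-prediction mode the paper uses.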
@@ -1055,7 +1068,7 @@ class DiffusionPrior(nn.Module):
                 x_start.clamp_(-1., 1.)

             if self.predict_x_start and self.sampling_clamp_l2norm:
-                x_start = l2norm(x_start) * self.image_embed_scale
+                x_start = self.l2norm_clamp_embed(x_start)

             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
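For reference, the DDIM update surrounding this change is image_embed = sqrt(alpha_next) * x_start + c1 * noise + c2 * pred_noise, with c1 and c2 as in the two context lines above. A quick numeric check (the alpha-bar values are made up; any 0 < alpha < alpha_next < 1 works) that the combined noise terms carry variance 1 - alpha_next regardless of eta, with eta = 0 giving the fully deterministic sampler:

    import torch

    alpha, alpha_next = torch.tensor(0.5), torch.tensor(0.7)  # illustrative alpha-bar values

    for eta in (0., 0.5, 1.):
        c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
        # c1 ** 2 + c2 ** 2 == 1 - alpha_next by construction; eta only shifts
        # variance between fresh noise (c1) and the predicted noise (c2)
        print(f'eta={eta}: c1^2 + c2^2 = {(c1 ** 2 + c2 ** 2).item():.4f}, 1 - alpha_next = {(1 - alpha_next).item():.4f}')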
@@ -1065,6 +1078,9 @@ class DiffusionPrior(nn.Module):
                 c1 * noise + \
                 c2 * pred_noise

+        if self.predict_x_start and self.sampling_final_clamp_l2norm:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed

     @torch.no_grad()
@@ -1091,7 +1107,7 @@ class DiffusionPrior(nn.Module):
         )

         if self.predict_x_start and self.training_clamp_l2norm:
-            pred = l2norm(pred) * self.image_embed_scale
+            pred = self.l2norm_clamp_embed(pred)

         target = noise if not self.predict_x_start else image_embed
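On the training side, the same helper renorms the network prediction before the loss is taken against the target. A condensed sketch of the objective as wired up here, assuming the mean-squared-error loss implied by the loss_type = "l2" default; p_losses_sketch and clamp_fn are hypothetical names for illustration:

    import torch
    import torch.nn.functional as F

    def p_losses_sketch(pred, noise, image_embed, predict_x_start, training_clamp_l2norm, clamp_fn):
        # optionally renorm the predicted x0 before comparing against the target
        if predict_x_start and training_clamp_l2norm:
            pred = clamp_fn(pred)

        # when predicting x0, regress the clean embed; otherwise regress the noise
        target = noise if not predict_x_start else image_embed
        return F.mse_loss(pred, target)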
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.19.5'
+__version__ = '0.19.6'