mirror of https://github.com/lucidrains/DALLE2-pytorch.git
make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder
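For context, forward_with_cond_scale (called in the new code below but not itself part of this diff) implements classifier-free guidance: the network is run once with conditioning and once with conditioning fully dropped, and the final prediction extrapolates away from the unconditional one by cond_scale. A minimal sketch of the mechanism, where a net whose forward accepts a cond_drop_prob override is a hypothetical signature for illustration:

    # sketch of classifier-free guidance, the mechanism behind forward_with_cond_scale;
    # `net` accepting a cond_drop_prob override is a hypothetical signature, not from this diff
    def forward_with_cond_scale(net, x, t, cond_scale = 1., **cond):
        pred = net(x, t, **cond)                             # conditional prediction

        if cond_scale == 1.:
            return pred                                      # no guidance requested

        null_pred = net(x, t, cond_drop_prob = 1., **cond)   # conditioning fully dropped
        # extrapolate away from the unconditional prediction by cond_scale
        return null_pred + (pred - null_pred) * cond_scale

The asserts added below guard the degenerate case: a network trained without conditional dropout has never seen null conditioning, so its unconditional prediction, and therefore any cond_scale other than 1., would be meaningless.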
@@ -901,6 +901,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.channels = default(image_channels, lambda: clip.image_channels)

         self.cond_drop_prob = cond_drop_prob
+        self.can_classifier_guidance = cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -914,8 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm

-    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
-        pred = self.net(x, t, **text_cond)
+    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)

         if self.predict_x_start:
             x_recon = pred
@@ -934,16 +937,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.inference_mode()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.inference_mode()
-    def p_sample_loop(self, shape, text_cond):
+    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
         device = self.betas.device

         b = shape[0]
@@ -954,7 +957,7 @@ class DiffusionPrior(BaseGaussianDiffusion):

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
             times = torch.full((b,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
+            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

         return image_embed

@@ -980,19 +983,19 @@ class DiffusionPrior(BaseGaussianDiffusion):

     @torch.inference_mode()
     @eval_decorator
-    def sample_batch_size(self, batch_size, text_cond):
+    def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
         device = self.betas.device
         shape = (batch_size, self.image_embed_dim)

         img = torch.randn(shape, device = device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
+            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img

     @torch.inference_mode()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2):
+    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1007,7 +1010,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)

         # retrieve original unscaled image embed

@@ -1793,6 +1796,7 @@ class Decoder(BaseGaussianDiffusion):

         self.image_cond_drop_prob = image_cond_drop_prob
         self.text_cond_drop_prob = text_cond_drop_prob
+        self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.

         # whether to clip when sampling

@@ -1819,6 +1823,8 @@ class Decoder(BaseGaussianDiffusion):
         unet.cpu()

     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
         pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))

         if learned_variance:
@@ -2094,6 +2100,7 @@ class DALLE2(nn.Module):
         self,
         text,
         cond_scale = 1.,
+        prior_cond_scale = 1.,
         return_pil_images = False
     ):
         device = next(self.parameters()).device
@@ -2103,7 +2110,7 @@ class DALLE2(nn.Module):
         text = [text] if not isinstance(text, (list, tuple)) else text
         text = tokenizer.tokenize(text).to(device)

-        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
+        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)

         text_cond = text if self.decoder_need_text_cond else None
         images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
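A usage sketch of the split guidance scales now exposed on DALLE2 (construction of prior and decoder elided; that both were trained with nonzero conditional dropout is an assumption these scales require):

    # hypothetical usage; any cond_scale other than 1. now requires the
    # corresponding model to have been trained with conditional dropout > 0
    dalle2 = DALLE2(prior = prior, decoder = decoder)

    images = dalle2(
        ['a corgi playing in a field of tulips'],
        prior_cond_scale = 1.,  # prior sampled without guidance
        cond_scale = 2.         # decoder guided more strongly
    )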