make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder

This commit is contained in:
Phil Wang
2022-05-15 19:09:38 -07:00
parent 36c5079bd7
commit ecf9e8027d
2 changed files with 19 additions and 12 deletions

View File

@@ -901,6 +901,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.channels = default(image_channels, lambda: clip.image_channels) self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -914,8 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
self.training_clamp_l2norm = training_clamp_l2norm self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm self.init_image_embed_l2norm = init_image_embed_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
pred = self.net(x, t, **text_cond) assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
if self.predict_x_start: if self.predict_x_start:
x_recon = pred x_recon = pred
@@ -934,16 +937,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.inference_mode() @torch.inference_mode()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False): def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised) model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.inference_mode() @torch.inference_mode()
def p_sample_loop(self, shape, text_cond): def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
@@ -954,7 +957,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long) times = torch.full((b,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond) image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
return image_embed return image_embed
@@ -980,19 +983,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
@torch.inference_mode() @torch.inference_mode()
@eval_decorator @eval_decorator
def sample_batch_size(self, batch_size, text_cond): def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
device = self.betas.device device = self.betas.device
shape = (batch_size, self.image_embed_dim) shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond) img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
return img return img
@torch.inference_mode() @torch.inference_mode()
@eval_decorator @eval_decorator
def sample(self, text, num_samples_per_batch = 2): def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
# in the paper, what they did was # in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch) text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1007,7 +1010,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if self.condition_on_text_encodings: if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask} text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
# retrieve original unscaled image embed # retrieve original unscaled image embed
@@ -1793,6 +1796,7 @@ class Decoder(BaseGaussianDiffusion):
self.image_cond_drop_prob = image_cond_drop_prob self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob self.text_cond_drop_prob = text_cond_drop_prob
self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
# whether to clip when sampling # whether to clip when sampling
@@ -1819,6 +1823,8 @@ class Decoder(BaseGaussianDiffusion):
unet.cpu() unet.cpu()
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)) pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
if learned_variance: if learned_variance:
@@ -2094,6 +2100,7 @@ class DALLE2(nn.Module):
self, self,
text, text,
cond_scale = 1., cond_scale = 1.,
prior_cond_scale = 1.,
return_pil_images = False return_pil_images = False
): ):
device = next(self.parameters()).device device = next(self.parameters()).device
@@ -2103,7 +2110,7 @@ class DALLE2(nn.Module):
text = [text] if not isinstance(text, (list, tuple)) else text text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device) text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples) image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
text_cond = text if self.decoder_need_text_cond else None text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale) images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.37', version = '0.2.38',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',