mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make sure classifier free guidance is used only if conditional dropout is present on the DiffusionPrior and Decoder classes. also make sure prior can have a different conditional scale than decoder
This commit is contained in:
@@ -901,6 +901,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.can_classifier_guidance = cond_drop_prob > 0.
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
@@ -914,8 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
@@ -934,16 +937,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, shape, text_cond):
|
||||
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
@@ -954,7 +957,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
return image_embed
|
||||
|
||||
@@ -980,19 +983,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample_batch_size(self, batch_size, text_cond):
|
||||
def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
|
||||
device = self.betas.device
|
||||
shape = (batch_size, self.image_embed_dim)
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
|
||||
# in the paper, what they did was
|
||||
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
|
||||
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
|
||||
@@ -1007,7 +1010,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
if self.condition_on_text_encodings:
|
||||
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
|
||||
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
|
||||
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
@@ -1793,6 +1796,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
self.image_cond_drop_prob = image_cond_drop_prob
|
||||
self.text_cond_drop_prob = text_cond_drop_prob
|
||||
self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
|
||||
|
||||
# whether to clip when sampling
|
||||
|
||||
@@ -1819,6 +1823,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
unet.cpu()
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
|
||||
|
||||
if learned_variance:
|
||||
@@ -2094,6 +2100,7 @@ class DALLE2(nn.Module):
|
||||
self,
|
||||
text,
|
||||
cond_scale = 1.,
|
||||
prior_cond_scale = 1.,
|
||||
return_pil_images = False
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
@@ -2103,7 +2110,7 @@ class DALLE2(nn.Module):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
|
||||
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
Reference in New Issue
Block a user