diff --git a/README.md b/README.md
index 4ad5e6a..58a6a2c 100644
--- a/README.md
+++ b/README.md
@@ -296,7 +296,10 @@ dalle2 = DALLE2(
     decoder = decoder
 )
 
-images = dalle2(['cute puppy chasing after a squirrel'])
+images = dalle2(
+    ['cute puppy chasing after a squirrel'],
+    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
+)
 
 # save your image
 ```
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index a886252..921d0f2 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -246,15 +246,16 @@ class DiffusionPriorNetwork(nn.Module):
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
-        if cond_scale == 1:
-            return self.forward(x, **kwargs)
+        logits = self.forward(x, *args, **kwargs)
 
-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        if cond_scale == 1:
+            return logits
+
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
@@ -635,15 +636,16 @@ class Unet(nn.Module):
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
-        if cond_scale == 1:
-            return self.forward(x, **kwargs)
+        logits = self.forward(x, *args, **kwargs)
 
-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        if cond_scale == 1:
+            return logits
+
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
@@ -774,8 +776,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
+    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -784,31 +786,31 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed):
+    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed):
+    def sample(self, image_embed, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -869,7 +871,8 @@ class DALLE2(nn.Module):
     @torch.no_grad()
     def forward(
         self,
-        text
+        text,
+        cond_scale = 1.
     ):
         device = next(self.parameters()).device
 
@@ -878,5 +881,5 @@ class DALLE2(nn.Module):
         text = tokenizer.tokenize(text).to(device)
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed)
+        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
         return images
diff --git a/setup.py b/setup.py
index eacc315..c514016 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
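Note: the `forward_with_cond_scale` methods in this diff implement classifier-free guidance. The network is run once with the conditioning intact and once with it fully dropped (`cond_drop_prob = 1.`), and the two predictions are linearly combined. A minimal standalone sketch of that combination, with a hypothetical `model` callable standing in for the real `DiffusionPriorNetwork` or `Unet` forward:

```python
import torch

def guided_forward(model, x, *args, cond_scale = 1., **kwargs):
    # conditioned prediction
    logits = model(x, *args, **kwargs)

    # at cond_scale == 1 guidance is a no-op, so skip the second forward pass
    if cond_scale == 1:
        return logits

    # unconditioned ("null") prediction, with conditioning fully dropped
    null_logits = model(x, *args, cond_drop_prob = 1., **kwargs)

    # extrapolate away from the null prediction;
    # cond_scale > 1 strengthens the conditioning signal
    return null_logits + (logits - null_logits) * cond_scale
```

With the branches ordered as in the diff, unguided sampling (`cond_scale = 1`) still costs a single forward pass per timestep, while any other scale costs two.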