make sure classifier free guidance condition scaling is exposed on DALLE2 forward function

commit 7e93b9d3c8 (parent 4c827ba94f)
Author: Phil Wang
Date: 2022-04-13 20:14:20 -07:00

3 changed files with 28 additions and 22 deletions
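For context: classifier-free guidance runs the denoising network twice per step, once with the conditioning and once with it dropped (`cond_drop_prob = 1.`), then blends the two predictions. A minimal sketch of the blend this commit threads through the sampler (the function name here is illustrative, not part of the library):

```python
import torch

def apply_cond_scale(logits: torch.Tensor, null_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # cond_scale = 0. -> unconditional, 1. -> purely conditional,
    # > 1. -> extrapolates away from the unconditional prediction,
    # strengthening the condition
    return null_logits + (logits - null_logits) * cond_scale
```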

File: README.md

@@ -296,7 +296,10 @@ dalle2 = DALLE2(
     decoder = decoder
 )
 
-images = dalle2(['cute puppy chasing after a squirrel'])
+images = dalle2(
+    ['cute puppy chasing after a squirrel'],
+    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
+)
 
 # save your image
 ```
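Since `cond_scale` is an ordinary keyword argument on the forward pass, it can be varied per call. A hypothetical sweep, assuming the `dalle2` object constructed earlier in the README:

```python
# cond_scale = 1. is plain conditional sampling; larger values trade
# sample diversity for stronger adherence to the text prompt
for scale in (1., 2., 3.):
    images = dalle2(['cute puppy chasing after a squirrel'], cond_scale = scale)
```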

File: dalle2_pytorch/dalle2_pytorch.py

@@ -246,15 +246,16 @@ class DiffusionPriorNetwork(nn.Module):
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
-        if cond_scale == 1:
-            return self.forward(x, **kwargs)
+        logits = self.forward(x, *args, **kwargs)
 
-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        if cond_scale == 1:
+            return logits
+
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
@@ -635,15 +636,16 @@ class Unet(nn.Module):
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
-        if cond_scale == 1:
-            return self.forward(x, **kwargs)
+        logits = self.forward(x, *args, **kwargs)
 
-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        if cond_scale == 1:
+            return logits
+
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
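Both hunks make the same change: the keyword-only marker `*` becomes `*args`, so positional arguments such as the timestep now reach the underlying `forward` (which `p_mean_variance` below relies on when it calls `forward_with_cond_scale(x, t, ...)`), and the conditional forward pass is hoisted above the `cond_scale == 1` early return. A self-contained toy of the pattern; the module here is illustrative, not the library's implementation:

```python
import torch
from torch import nn

class ToyConditionalNet(nn.Module):
    # stand-in for DiffusionPriorNetwork / Unet
    def __init__(self, dim = 16):
        super().__init__()
        self.proj = nn.Linear(dim * 2, dim)
        self.null_cond = nn.Parameter(torch.zeros(dim))  # learned null conditioning

    def forward(self, x, t, cond = None, cond_drop_prob = 0.):
        # t is accepted positionally (unused in this toy) to mirror the real signature
        if cond is None or cond_drop_prob >= 1.:
            cond = self.null_cond.expand(x.shape[0], -1)
        return self.proj(torch.cat((x, cond), dim = -1))

    def forward_with_cond_scale(self, x, *args, cond_scale = 1., **kwargs):
        logits = self.forward(x, *args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

net = ToyConditionalNet()
x, t, cond = torch.randn(4, 16), torch.zeros(4), torch.randn(4, 16)
out = net.forward_with_cond_scale(x, t, cond = cond, cond_scale = 2.)  # t passes through *args
```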
@@ -774,8 +776,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
+    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -784,31 +786,31 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed):
+    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed):
+    def sample(self, image_embed, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -869,7 +871,8 @@ class DALLE2(nn.Module):
     @torch.no_grad()
     def forward(
         self,
-        text
+        text,
+        cond_scale = 1.
     ):
         device = next(self.parameters()).device
@@ -878,5 +881,5 @@ class DALLE2(nn.Module):
         text = tokenizer.tokenize(text).to(device)
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed)
+        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
         return images
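Note that this commit threads `cond_scale` only through the decoder; the prior's `sample` call above is unchanged. The flow through the decoder's ancestral sampler, summarized from the hunks above:

```python
# DALLE2.forward(text, cond_scale = ...)
#   -> Decoder.sample(image_embed, cond_scale = ...)
#     -> Decoder.p_sample_loop(shape, image_embed, cond_scale = ...)
#       -> Decoder.p_sample(x, t, image_embed, cond_scale = ...)
#         -> Decoder.p_mean_variance(x, t, image_embed, cond_scale = ...)
#           -> net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale)
```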

File: setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',