mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
make sure classifier free guidance condition scaling is exposed on DALLE2 forward function
This commit is contained in:
@@ -296,7 +296,10 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(['cute puppy chasing after a squirrel'])
|
||||
images = dalle2(
|
||||
['cute puppy chasing after a squirrel'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image
|
||||
```
|
||||
|
||||
@@ -246,15 +246,16 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(x, *args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -635,15 +636,16 @@ class Unet(nn.Module):
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(x, *args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -774,8 +776,8 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -784,31 +786,31 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, image_embed):
|
||||
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, image_embed):
|
||||
def sample(self, image_embed, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
@@ -869,7 +871,8 @@ class DALLE2(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
text
|
||||
text,
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
|
||||
@@ -878,5 +881,5 @@ class DALLE2(nn.Module):
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user