mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
Make sure classifier-free guidance condition scaling is exposed on the DALLE2 forward function
This commit is contained in:
@@ -296,7 +296,10 @@ dalle2 = DALLE2(
|
|||||||
decoder = decoder
|
decoder = decoder
|
||||||
)
|
)
|
||||||
|
|
||||||
images = dalle2(['cute puppy chasing after a squirrel'])
|
images = dalle2(
|
||||||
|
['cute puppy chasing after a squirrel'],
|
||||||
|
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||||
|
)
|
||||||
|
|
||||||
# save your image
|
# save your image
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -246,15 +246,16 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
*,
|
*args,
|
||||||
cond_scale = 1.,
|
cond_scale = 1.,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if cond_scale == 1:
|
logits = self.forward(x, *args, **kwargs)
|
||||||
return self.forward(x, **kwargs)
|
|
||||||
|
|
||||||
logits = self.forward(x, **kwargs)
|
if cond_scale == 1:
|
||||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
return logits
|
||||||
|
|
||||||
|
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||||
return null_logits + (logits - null_logits) * cond_scale
|
return null_logits + (logits - null_logits) * cond_scale
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -635,15 +636,16 @@ class Unet(nn.Module):
|
|||||||
def forward_with_cond_scale(
|
def forward_with_cond_scale(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
*,
|
*args,
|
||||||
cond_scale = 1.,
|
cond_scale = 1.,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if cond_scale == 1:
|
logits = self.forward(x, *args, **kwargs)
|
||||||
return self.forward(x, **kwargs)
|
|
||||||
|
|
||||||
logits = self.forward(x, **kwargs)
|
if cond_scale == 1:
|
||||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
return logits
|
||||||
|
|
||||||
|
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||||
return null_logits + (logits - null_logits) * cond_scale
|
return null_logits + (logits - null_logits) * cond_scale
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -774,8 +776,8 @@ class Decoder(nn.Module):
|
|||||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||||
|
|
||||||
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
|
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
|
||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
|
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
|
||||||
|
|
||||||
if clip_denoised:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
@@ -784,31 +786,31 @@ class Decoder(nn.Module):
|
|||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
|
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = noise_like(x.shape, device, repeat_noise)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_loop(self, shape, image_embed):
|
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
img = torch.randn(shape, device=device)
|
img = torch.randn(shape, device=device)
|
||||||
|
|
||||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
|
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample(self, image_embed):
|
def sample(self, image_embed, cond_scale = 1.):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
image_size = self.image_size
|
image_size = self.image_size
|
||||||
channels = self.channels
|
channels = self.channels
|
||||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
|
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
|
||||||
|
|
||||||
def q_sample(self, x_start, t, noise=None):
|
def q_sample(self, x_start, t, noise=None):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
@@ -869,7 +871,8 @@ class DALLE2(nn.Module):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
text
|
text,
|
||||||
|
cond_scale = 1.
|
||||||
):
|
):
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
|
|
||||||
@@ -878,5 +881,5 @@ class DALLE2(nn.Module):
|
|||||||
text = tokenizer.tokenize(text).to(device)
|
text = tokenizer.tokenize(text).to(device)
|
||||||
|
|
||||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||||
images = self.decoder.sample(image_embed)
|
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||||
return images
|
return images
|
||||||
|
|||||||
Reference in New Issue
Block a user