Compare commits

...

6 Commits
0.0.6 ... 0.0.7

Author SHA1 Message Date
Phil Wang
7e93b9d3c8 make sure classifier free guidance condition scaling is exposed on DALLE2 forward function 2022-04-13 20:14:28 -07:00
Phil Wang
4c827ba94f typo 2022-04-13 19:01:03 -07:00
Phil Wang
cb3923a90f readme tweak 2022-04-13 18:43:34 -07:00
Phil Wang
cc30676a3f lengthen todo 2022-04-13 18:34:09 -07:00
Phil Wang
c7fb327618 link to x-clip 2022-04-13 18:26:30 -07:00
Phil Wang
14ddbc159c cleanup 2022-04-13 18:24:32 -07:00
3 changed files with 34 additions and 26 deletions

View File

@@ -34,7 +34,7 @@ Once built, images will be saved to the same directory the command is invoked
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
This repository will demonstrate integration with `x-clip` for starters
@@ -136,12 +136,14 @@ loss.backward()
# then it will learn to generate images based on the CLIP image embeddings
```
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
@@ -199,7 +201,7 @@ dalle2 = DALLE2(
decoder = decoder
)
# send the text as a string if you want to use the simple tokenizer from DALL-E1
# send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer
texts = ['glistening morning dew on a flower petal']
@@ -294,7 +296,10 @@ dalle2 = DALLE2(
decoder = decoder
)
images = dalle2(['cute puppy chasing after a squirrel'])
images = dalle2(
['cute puppy chasing after a squirrel'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
```
@@ -317,6 +322,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
## Citations

View File

@@ -246,15 +246,16 @@ class DiffusionPriorNetwork(nn.Module):
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
if cond_scale == 1:
return self.forward(x, **kwargs)
logits = self.forward(x, *args, **kwargs)
logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -635,15 +636,16 @@ class Unet(nn.Module):
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
if cond_scale == 1:
return self.forward(x, **kwargs)
logits = self.forward(x, *args, **kwargs)
logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -774,8 +776,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -784,31 +786,31 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed):
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed):
def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -869,7 +871,8 @@ class DALLE2(nn.Module):
@torch.no_grad()
def forward(
self,
text
text,
cond_scale = 1.
):
device = next(self.parameters()).device
@@ -877,7 +880,6 @@ class DALLE2(nn.Module):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
print(text.shape, type(text))
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.5',
version = '0.0.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',