Compare commits

...

12 Commits
0.0.6 ... 0.0.9

Author SHA1 Message Date
Phil Wang
f2c52d8239 fix bug with classifier free guidance for prior network, even though it seems it may not be used 2022-04-14 09:21:51 -07:00
Phil Wang
97e951221b bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out 2022-04-14 09:16:09 -07:00
Phil Wang
e1b0c140f1 cleanup readme 2022-04-14 08:51:22 -07:00
Phil Wang
5989569a44 link to OpenCLIP effort 2022-04-14 08:31:15 -07:00
Phil Wang
82464d7bd3 per-fect 2022-04-14 08:30:07 -07:00
Phil Wang
7fb3f695d5 offer continuously parameterized time embedding for diffusion prior network, remove a hyperparameter that may trip up people, if not set correctly 2022-04-14 08:28:11 -07:00
Phil Wang
7e93b9d3c8 make sure classifier free guidance condition scaling is exposed on DALLE2 forward function 2022-04-13 20:14:28 -07:00
Phil Wang
4c827ba94f typo 2022-04-13 19:01:03 -07:00
Phil Wang
cb3923a90f readme tweak 2022-04-13 18:43:34 -07:00
Phil Wang
cc30676a3f lengthen todo 2022-04-13 18:34:09 -07:00
Phil Wang
c7fb327618 link to x-clip 2022-04-13 18:26:30 -07:00
Phil Wang
14ddbc159c cleanup 2022-04-13 18:24:32 -07:00
3 changed files with 111 additions and 42 deletions

View File

@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch $ pip install dalle2-pytorch
``` ```
## CLI Usage (work in progress) ## Usage
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
## Training (for deep learning practitioners)
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway. To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
This repository will demonstrate integration with `x-clip` for starters This repository will demonstrate integration with `x-clip` for starters
@@ -136,12 +128,14 @@ loss.backward()
# then it will learn to generate images based on the CLIP image embeddings # then it will learn to generate images based on the CLIP image embeddings
``` ```
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
```python ```python
import torch import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP( clip = CLIP(
dim_text = 512, dim_text = 512,
dim_image = 512, dim_image = 512,
@@ -160,7 +154,6 @@ clip = CLIP(
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8
@@ -199,7 +192,7 @@ dalle2 = DALLE2(
decoder = decoder decoder = decoder
) )
# send the text as a string if you want to use the simple tokenizer from DALL-E1 # send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer # or you can do it as token ids, if you have your own tokenizer
texts = ['glistening morning dew on a flower petal'] texts = ['glistening morning dew on a flower petal']
@@ -249,7 +242,6 @@ loss.backward()
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = 512, dim = 512,
num_timesteps = 100,
depth = 6, depth = 6,
dim_head = 64, dim_head = 64,
heads = 8 heads = 8
@@ -294,7 +286,10 @@ dalle2 = DALLE2(
decoder = decoder decoder = decoder
) )
images = dalle2(['cute puppy chasing after a squirrel']) images = dalle2(
['cute puppy chasing after a squirrel'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image # save your image
``` ```
@@ -303,6 +298,18 @@ Everything in this readme should run without error
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training. For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip) ## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a> <a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -317,6 +324,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper) - [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention - [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
## Citations ## Citations

View File

@@ -7,9 +7,12 @@ import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer from dalle2_pytorch.tokenizer import tokenizer
# use x-clip # use x-clip
@@ -124,6 +127,43 @@ class PreNormResidual(nn.Module):
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x return self.fn(self.norm(x), **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# feedforward
def FeedForward(dim, mult = 4, dropout = 0.): def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim) inner_dim = int(mult * dim)
return nn.Sequential( return nn.Sequential(
@@ -134,6 +174,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
nn.Linear(inner_dim, dim, bias = False) nn.Linear(inner_dim, dim, bias = False)
) )
# attention
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
@@ -235,26 +277,26 @@ class DiffusionPriorNetwork(nn.Module):
def __init__( def __init__(
self, self,
dim, dim,
num_timesteps = 1000, num_timesteps = None,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs) self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
x, *args,
*,
cond_scale = 1., cond_scale = 1.,
**kwargs **kwargs
): ):
if cond_scale == 1: logits = self.forward(*args, **kwargs)
return self.forward(x, **kwargs)
logits = self.forward(x, **kwargs) if cond_scale == 1:
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs) return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -274,8 +316,15 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d') text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask): if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps) time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d') time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -634,16 +683,16 @@ class Unet(nn.Module):
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
x, *args,
*,
cond_scale = 1., cond_scale = 1.,
**kwargs **kwargs
): ):
if cond_scale == 1: logits = self.forward(*args, **kwargs)
return self.forward(x, **kwargs)
logits = self.forward(x, **kwargs) if cond_scale == 1:
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs) return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -691,6 +740,17 @@ class Unet(nn.Module):
return self.final_conv(x) return self.final_conv(x)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__( def __init__(
self, self,
@@ -774,8 +834,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool): def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed)) x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
if clip_denoised: if clip_denoised:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
@@ -784,31 +844,31 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False): def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised) model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad() @torch.no_grad()
def p_sample_loop(self, shape, image_embed): def p_sample_loop(self, shape, image_embed, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
img = torch.randn(shape, device=device) img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed) img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
return img return img
@torch.no_grad() @torch.no_grad()
def sample(self, image_embed): def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
image_size = self.image_size image_size = self.image_size
channels = self.channels channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed) return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
@@ -869,7 +929,8 @@ class DALLE2(nn.Module):
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
text text,
cond_scale = 1.
): ):
device = next(self.parameters()).device device = next(self.parameters()).device
@@ -877,7 +938,6 @@ class DALLE2(nn.Module):
text = [text] if not isinstance(text, (list, tuple)) else text text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device) text = tokenizer.tokenize(text).to(device)
print(text.shape, type(text))
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples) image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed) images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images return images

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.5', version = '0.0.9',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
'click', 'click',
'einops>=0.4', 'einops>=0.4',
'einops-exts>=0.0.3', 'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow', 'pillow',
'torch>=1.10', 'torch>=1.10',
'torchvision', 'torchvision',