mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fb3f695d5 | ||
|
|
7e93b9d3c8 | ||
|
|
4c827ba94f | ||
|
|
cb3923a90f | ||
|
|
cc30676a3f | ||
|
|
c7fb327618 | ||
|
|
14ddbc159c | ||
|
|
0692f1699f | ||
|
|
26c4534bc3 |
22
README.md
22
README.md
@@ -34,7 +34,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
|
||||
|
||||
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.
|
||||
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
|
||||
|
||||
This repository will demonstrate integration with `x-clip` for starters
|
||||
|
||||
@@ -136,12 +136,14 @@ loss.backward()
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
```
|
||||
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
@@ -160,7 +162,6 @@ clip = CLIP(
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -199,7 +200,7 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
# send the text as a string if you want to use the simple tokenizer from DALL-E1
|
||||
# send the text as a string if you want to use the simple tokenizer from DALLE v1
|
||||
# or you can do it as token ids, if you have your own tokenizer
|
||||
|
||||
texts = ['glistening morning dew on a flower petal']
|
||||
@@ -214,8 +215,6 @@ Let's see the whole script below
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
|
||||
|
||||
import torch
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
@@ -251,7 +250,6 @@ loss.backward()
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -296,13 +294,18 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(['cute puppy chasing after a squirrel'])
|
||||
images = dalle2(
|
||||
['cute puppy chasing after a squirrel'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image
|
||||
```
|
||||
|
||||
Everything in this readme should run without error
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
@@ -317,6 +320,7 @@ Everything in this readme should run without error
|
||||
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention
|
||||
- [ ] figure out the big idea behind latent diffusion and what can be ported over
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -364,3 +368,5 @@ Everything in this readme should run without error
|
||||
primaryClass = {cs.LG}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
@@ -124,6 +125,43 @@ class PreNormResidual(nn.Module):
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs) + x
|
||||
|
||||
# mlp
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
*,
|
||||
expansion_factor = 2.,
|
||||
depth = 2,
|
||||
norm = False,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(expansion_factor * dim_out)
|
||||
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
|
||||
|
||||
layers = [nn.Sequential(
|
||||
nn.Linear(dim_in, hidden_dim),
|
||||
nn.SiLU(),
|
||||
norm_fn()
|
||||
)]
|
||||
|
||||
for _ in range(depth - 1):
|
||||
layers.append(nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
norm_fn()
|
||||
))
|
||||
|
||||
layers.append(nn.Linear(hidden_dim, dim_out))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x.float())
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
@@ -134,6 +172,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -235,26 +275,27 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = 1000,
|
||||
num_timesteps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(x, *args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -635,15 +676,16 @@ class Unet(nn.Module):
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(x, *args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -774,8 +816,8 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -784,31 +826,31 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, image_embed):
|
||||
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, image_embed):
|
||||
def sample(self, image_embed, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
@@ -869,7 +911,8 @@ class DALLE2(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
text
|
||||
text,
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
|
||||
@@ -877,7 +920,6 @@ class DALLE2(nn.Module):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
print(text.shape, type(text))
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user