mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 01:34:19 +01:00
fix everything and make sure it runs end to end, document everything in readme for public
This commit is contained in:

283 README.md
@@ -22,9 +22,7 @@ For all of you emailing me (there is a lot), the best way to contribute is throu

```bash
$ pip install dalle2-pytorch
```

## Usage (work in progress)

<a href="https://github.com/lucidrains/big-sleep">template</a>

## CLI Usage (work in progress)

```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```

@@ -32,17 +30,288 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'

Once built, images will be saved to the same directory from which the command is invoked.

## Training (for deep learning practitioners)

Training DALL-E 2 is a 3-step process, with the training of CLIP being the most important.

To train CLIP, you can either use the `x-clip` package, or join the LAION discord, where many replication efforts are already underway.

This repository will demonstrate integration with `x-clip` for starters.

```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True,            # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self-supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True  # needs to be set to True to return contrastive loss
)

loss.backward()

# do the above with as many texts and images as possible in a loop
```
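
To put the step above in context, here is a minimal training-loop sketch. The optimizer, learning rate, and loop length are illustrative assumptions, not part of this repository; substitute a real (text, image) dataloader in practice.

```python
# hypothetical training loop - Adam and the 3e-4 learning rate are assumptions
from torch.optim import Adam

opt = Adam(clip.parameters(), lr = 3e-4)

for _ in range(1000):
    # in practice, draw fresh (text, image) batches from a dataloader here
    loss = clip(text, images, return_loss = True)
    opt.zero_grad()
    loss.backward()
    opt.step()
```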

Then, you will need to train the decoder, which learns to generate images conditioned on the image embeddings coming from the trained CLIP above.

```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# unet for the decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    net = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings
```
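
Once trained, images can be sampled from CLIP image embeddings. A minimal sketch using the same `decoder.sample` call that `DALLE2` makes internally (see the source diff further down); the random vector is just a stand-in for real CLIP image embeddings:

```python
# sampling sketch - the random embedding stands in for real CLIP image embeddings
mock_image_embed = torch.randn(4, 512).cuda()
sampled_images = decoder.sample(mock_image_embed)  # (4, 3, 256, 256)
```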

Finally, the main contribution of the paper: the diffusion prior network. It takes the CLIP text embeddings and learns to generate the corresponding CLIP image embeddings. Again, you will need the trained CLIP from the first step.

```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(text, images)
loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
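
After training, image embeddings can be sampled from tokenized text. A sketch mirroring the `diffusion_prior.sample` call that `DALLE2.forward` makes internally (see the source diff below); the output shape assumes the 512-dimensional latent above:

```python
# sample CLIP image embeddings from text token ids - mirrors the call in DALLE2.forward
image_embeds = diffusion_prior.sample(text, num_samples_per_batch = 2)  # (4, 512)
```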

Finally, to generate the DALL-E 2 images from text, pass in the trained `DiffusionPrior` as well as the `Decoder` (each of which wraps `CLIP`; the decoder holds the unet, and the prior holds the causal transformer).

```python
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# send the text as a string if you want to use the simple tokenizer from DALL-E1
# or you can do it as token ids, if you have your own tokenizer

texts = ['glistening morning dew on a flower petal']
images = dalle2(texts)  # (1, 3, 256, 256)
```
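
Since `DALLE2` in this commit also accepts a raw string (strings and lists of strings are tokenized internally with the simple tokenizer, per the source diff below), the following call is equivalent; the shape comment assumes the 256-pixel CLIP above:

```python
# equivalent call with a raw string - tokenized internally by the simple tokenizer
images = dalle2('glistening morning dew on a flower petal')  # (1, 3, 256, 256)
```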

That's it!

Let's see the whole script below:

```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

decoder = Decoder(
    net = unet,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = decoder(images)
loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(['cute puppy chasing after a squirrel'])

# save your image
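# a hypothetical save step - save_image comes from torchvision (already in
# this repo's install_requires), but this exact usage is an illustration,
# not repository API
from torchvision.utils import save_image
save_image(images, './puppy.png')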
```

Everything in this readme should run without error.

## Training CLI (wip)

<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>

Todo

## Todo

- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on the DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in the paper)

dalle2_pytorch/__init__.py

@@ -1 +1,2 @@
-from dalle2_pytorch.dalle2_pytorch import DALLE2
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+from x_clip import CLIP

dalle2_pytorch/dalle2_pytorch.py

@@ -1,10 +1,16 @@
-import tqdm
+import math
+from tqdm import tqdm
+from inspect import isfunction

 import torch
 import torch.nn.functional as F
 from torch import nn, einsum

-from einops import rearrange
-from einops_exts import rearrange_many, repeat_many
+from einops import rearrange, repeat
+from einops_exts import rearrange_many, repeat_many, check_shape
+from einops_exts.torch import EinopsToAndFrom

 from dalle2_pytorch.tokenizer import tokenizer

 # use x-clip

@@ -16,7 +22,9 @@ def exists(val):
     return val is not None

 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d

 def eval_decorator(fn):
     def inner(model, *args, **kwargs):

@@ -27,6 +35,11 @@ def eval_decorator(fn):
         return out
     return inner

+def is_list_str(x):
+    if not isinstance(x, (list, tuple)):
+        return False
+    return all([type(el) == str for el in x])

 # for controlling freezing of CLIP

 def set_module_requires_grad_(module, requires_grad):

@@ -43,6 +56,11 @@ def freeze_model_and_make_eval_(model):
     model.eval()
     freeze_all_layers_(model)

+# tensor helpers
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)

 # classifier free guidance functions

 def prob_mask_like(shape, prob, device):

@@ -91,9 +109,16 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale

+class ChanRMSNorm(RMSNorm):
+    def forward(self, x):
+        squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
+        inv_norm = torch.rsqrt(squared_sum + self.eps)
+        return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale

 class PreNormResidual(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
         self.norm = RMSNorm(dim)

     def forward(self, x, **kwargs):

@@ -112,8 +137,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
 class Attention(nn.Module):
     def __init__(
         self,
-        *,
         dim,
+        *,
         dim_head = 64,
         heads = 8,
         dropout = 0.,

@@ -121,6 +146,7 @@ class Attention(nn.Module):
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
+        self.heads = heads
         inner_dim = dim_head * heads

         self.causal = causal

@@ -128,17 +154,17 @@ class Attention(nn.Module):
         self.dropout = nn.Dropout(dropout)

         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
-        self.to_qkv = nn.Linear(dim, inner_dim, bias = False)
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

     def forward(self, x, mask = None):
-        b, n, device = x.shape[:2], x.device
+        b, n, device = *x.shape[:2], x.device

         x = self.norm(x)
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

-        q = rearrange(q, 'b n (h d) -> b h n d')
+        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

         # add null key / value for classifier free guidance in prior net

@@ -148,7 +174,7 @@ class Attention(nn.Module):

         q = q * self.scale

-        sim = einsum('b h i d, b j d -> b h i j')
+        sim = einsum('b h i d, b j d -> b h i j', q, k)
         max_neg_value = -torch.finfo(sim.dtype).max

         if exists(mask):

@@ -157,7 +183,8 @@ class Attention(nn.Module):
             sim = sim.masked_fill(~mask, max_neg_value)

         if self.causal:
-            causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)

         sim = sim - sim.amax(dim = -1, keepdim = True)

@@ -214,7 +241,7 @@ class DiffusionPriorNetwork(nn.Module):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
-        self.causal_transformer = CausalTransformer(**kwargs)
+        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

     def forward_with_cond_scale(
         self,

@@ -227,7 +254,7 @@ class DiffusionPriorNetwork(nn.Module):
             return self.forward(x, **kwargs)

         logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(

@@ -248,9 +275,10 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')

         if exists(mask):
-            mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

         time_embed = self.time_embeddings(diffusion_timesteps)
+        time_embed = rearrange(time_embed, 'b d -> b 1 d')

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

@@ -268,7 +296,7 @@ class DiffusionPriorNetwork(nn.Module):

         # classifier free guidance

-        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
         mask &= rearrange(cond_prob_mask, 'b -> b 1')

         # attend

@@ -288,19 +316,20 @@ class DiffusionPrior(nn.Module):
         *,
         clip,
         timesteps = 1000,
-        cond_prob_drop = 0.2,
+        cond_drop_prob = 0.2,
         loss_type = 'l1',
         predict_x0 = True
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip

         self.net = net
         self.image_embed_dim = clip.dim_latent
         self.channels = clip.image_channels
         self.image_size = clip.image_size
-        self.cond_prob_drop = cond_prob_drop
+        self.cond_drop_prob = cond_drop_prob

         self.predict_x0 = predict_x0
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

@@ -389,7 +418,7 @@ class DiffusionPrior(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, text_cond = None, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)

@@ -420,18 +449,18 @@ class DiffusionPrior(nn.Module):
         text_cond = self.get_text_cond(text)

         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
-        text_embeds = text_cond['text_embeds']
+        text_embeds = text_cond['text_embed']

         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
         image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

-        text_image_sims = einsum('b r d, b r d -> b r')
+        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
         top_sim_indices = text_image_sims.topk(k = 1).indices

-        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b d', d = image_embed_dim)
+        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)

         top_image_embeds = image_embeds.gather(1, top_sim_indices)
-        return top_image_embeds
+        return rearrange(top_image_embeds, 'b 1 d -> b d')

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))

@@ -442,14 +471,14 @@ class DiffusionPrior(nn.Module):
         )

     def p_losses(self, image_embed, t, text_cond, noise = None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
+        noise = default(noise, lambda: torch.randn_like(image_embed))

         image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)

         x_recon = self.net(
             image_embed_noisy,
             t,
-            cond_prob_drop = self.cond_prob_drop,
+            cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )

@@ -472,7 +501,7 @@ class DiffusionPrior(nn.Module):
         image_embed = self.get_image_embed(image)
         text_cond = self.get_text_cond(text)

-        loss = self.p_losses(x, times, image_embed = image_embed, text_cond = text_cond, *args, **kwargs)
+        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
         return loss

 # decoder

@@ -519,7 +548,7 @@ class ConvNextBlock(nn.Module):

         inner_dim = int(dim_out * mult)
         self.net = nn.Sequential(
-            RMSNorm(dim) if norm else nn.Identity(),
+            ChanRMSNorm(dim) if norm else nn.Identity(),
             nn.Conv2d(dim, inner_dim, 3, padding = 1),
             nn.GELU(),
             nn.Conv2d(inner_dim, dim_out, 3, padding = 1)

@@ -538,21 +567,6 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)

-class EinopsToAndFrom(nn.Module):
-    def __init__(self, from_einops, to_einops, fn):
-        super().__init__()
-        self.from_einops = from_einops
-        self.to_einops = to_einops
-        self.fn = fn
-
-    def forward(self, x, **kwargs):
-        shape = x.shape
-        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
-        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
-        x = self.fn(x, **kwargs)
-        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
-        return x

 class Unet(nn.Module):
     def __init__(
         self,

@@ -597,6 +611,7 @@ class Unet(nn.Module):
             ]))

         mid_dim = dims[-1]

         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

@@ -627,7 +642,7 @@ class Unet(nn.Module):
             return self.forward(x, **kwargs)

         logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(

@@ -637,11 +652,12 @@ class Unet(nn.Module):
         *,
         image_embed,
         text_encodings = None,
-        cond_prob_drop = 0.
+        cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
         t = self.time_mlp(time)

-        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)

         # mask out image embedding depending on condition dropout
         # for classifier free guidance

@@ -652,7 +668,7 @@ class Unet(nn.Module):
             rearrange(self.null_image_embed, 'd -> 1 d')
         )

-        cond = torch.cat((t, image_embed), dim = -1)
+        t = torch.cat((t, image_embed), dim = -1)

         hiddens = []

@@ -663,7 +679,7 @@ class Unet(nn.Module):
             x = downsample(x)

         x = self.mid_block1(x, t)
-        x = self.attn(x)
+        x = self.mid_attn(x)
         x = self.mid_block2(x, t)

         for convnext, convnext2, upsample in self.ups:

@@ -681,17 +697,18 @@ class Decoder(nn.Module):
         *,
         clip,
         timesteps = 1000,
-        cond_prob_drop = 0.2,
+        cond_drop_prob = 0.2,
         loss_type = 'l1'
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip

         self.net = net
         self.channels = clip.image_channels
         self.image_size = clip.image_size
-        self.cond_prob_drop = cond_prob_drop
+        self.cond_drop_prob = cond_drop_prob

         betas = cosine_beta_schedule(timesteps)

@@ -768,7 +785,7 @@ class Decoder(nn.Module):
     @torch.no_grad()
     def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

@@ -800,7 +817,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, image_embed, t, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)

@@ -809,7 +826,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
-            cond_prob_drop = self.cond_prob_drop
+            cond_drop_prob = self.cond_drop_prob
         )

         if self.loss_type == 'l1':

@@ -821,14 +838,14 @@ class Decoder(nn.Module):

         return loss

-    def forward(self, image, *args, **kwargs):
+    def forward(self, image):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
         image_embed = self.get_image_embed(image)

-        loss = self.p_losses(x, times, image_embed = image_embed, *args, **kwargs)
+        loss = self.p_losses(image, times, image_embed = image_embed)
         return loss

 # main class

@@ -839,23 +856,27 @@ class DALLE2(nn.Module):
         *,
         prior,
         decoder,
-        tokenizer = None
+        prior_num_samples = 2
     ):
         super().__init__()
-        assert isinstance(prior), DiffusionPrior
-        assert isinstance(decoder), Decoder
-        self.tokenizer = tokenizer
+        assert isinstance(prior, DiffusionPrior)
+        assert isinstance(decoder, Decoder)
         self.prior = prior.eval()
         self.decoder = decoder.eval()
+        self.prior_num_samples = prior_num_samples

     @torch.no_grad()
     def forward(
         self,
-        *,
         text
     ):
-        if isinstance(text, str):
-            assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
-            text = self.tokenizer.encode(text)
+        device = next(self.parameters()).device

-        image_embed = prior.sample(text, num_samples_per_batch = 2)
-        images = decoder.sample(image_embed)
+        if isinstance(text, str) or is_list_str(text):
+            text = [text] if not isinstance(text, (list, tuple)) else text
+            text = tokenizer.tokenize(text).to(device)

+        print(text.shape, type(text))
+        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
+        images = self.decoder.sample(image_embed)
         return images

4 setup.py

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.0.3',
+    version = '0.0.4',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',

@@ -24,7 +24,7 @@ setup(
     install_requires=[
         'click',
         'einops>=0.4',
-        'einops-exts',
+        'einops-exts>=0.0.3',
         'pillow',
         'torch>=1.10',
         'torchvision',