mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e27f617f1 | ||
|
|
9f55c24db6 | ||
|
|
69e822b7f8 | ||
|
|
23c401a5d5 | ||
|
|
68e9883f59 | ||
|
|
95b018374a | ||
|
|
8b5c2385b0 | ||
|
|
f2c52d8239 | ||
|
|
97e951221b | ||
|
|
e1b0c140f1 | ||
|
|
5989569a44 | ||
|
|
82464d7bd3 | ||
|
|
7fb3f695d5 | ||
|
|
7e93b9d3c8 | ||
|
|
4c827ba94f | ||
|
|
cb3923a90f | ||
|
|
cc30676a3f | ||
|
|
c7fb327618 | ||
|
|
14ddbc159c | ||
|
|
0692f1699f | ||
|
|
26c4534bc3 | ||
|
|
5e06cde4cb | ||
|
|
a1a8a78f21 |
55
README.md
55
README.md
@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
|
||||
$ pip install dalle2-pytorch
|
||||
```
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
```
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training (for deep learning practitioners)
|
||||
## Usage
|
||||
|
||||
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
|
||||
|
||||
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.
|
||||
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
|
||||
|
||||
This repository will demonstrate integration with `x-clip` for starters
|
||||
|
||||
@@ -109,7 +101,7 @@ clip = CLIP(
|
||||
unet = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
@@ -136,12 +128,14 @@ loss.backward()
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
```
|
||||
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step
|
||||
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
@@ -160,7 +154,6 @@ clip = CLIP(
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -199,7 +192,7 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
# send the text as a string if you want to use the simple tokenizer from DALL-E1
|
||||
# send the text as a string if you want to use the simple tokenizer from DALLE v1
|
||||
# or you can do it as token ids, if you have your own tokenizer
|
||||
|
||||
texts = ['glistening morning dew on a flower petal']
|
||||
@@ -212,10 +205,7 @@ Let's see the whole script below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
|
||||
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
@@ -252,7 +242,6 @@ loss.backward()
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
num_timesteps = 100,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
@@ -275,7 +264,7 @@ loss.backward()
|
||||
unet = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
time_dim = 128,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
).cuda()
|
||||
@@ -287,7 +276,7 @@ decoder = Decoder(
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
loss = decoder(images)
|
||||
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -297,13 +286,30 @@ dalle2 = DALLE2(
|
||||
decoder = decoder
|
||||
)
|
||||
|
||||
images = dalle2(['cute puppy chasing after a squirrel'])
|
||||
images = dalle2(
|
||||
['cute puppy chasing after a squirrel'],
|
||||
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
|
||||
)
|
||||
|
||||
# save your image
|
||||
```
|
||||
|
||||
Everything in this readme should run without error
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
```
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
@@ -313,11 +319,12 @@ Everything in this readme should run without error
|
||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
|
||||
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
|
||||
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
|
||||
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention
|
||||
- [ ] figure out the big idea behind latent diffusion and what can be ported over
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -365,3 +372,5 @@ Everything in this readme should run without error
|
||||
primaryClass = {cs.LG}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -7,9 +7,12 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters import filter2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
# use x-clip
|
||||
@@ -115,25 +118,110 @@ class ChanRMSNorm(RMSNorm):
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||
|
||||
class PreNormResidual(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs) + x
|
||||
return self.fn(x, **kwargs) + x
|
||||
|
||||
# mlp
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
*,
|
||||
expansion_factor = 2.,
|
||||
depth = 2,
|
||||
norm = False,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(expansion_factor * dim_out)
|
||||
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
|
||||
|
||||
layers = [nn.Sequential(
|
||||
nn.Linear(dim_in, hidden_dim),
|
||||
nn.SiLU(),
|
||||
norm_fn()
|
||||
)]
|
||||
|
||||
for _ in range(depth - 1):
|
||||
layers.append(nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.SiLU(),
|
||||
norm_fn()
|
||||
))
|
||||
|
||||
layers.append(nn.Linear(hidden_dim, dim_out))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x.float())
|
||||
|
||||
# relative positional bias for causal transformer
|
||||
|
||||
class RelPosBias(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
heads = 8,
|
||||
num_buckets = 32,
|
||||
max_distance = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.max_distance = max_distance
|
||||
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position,
|
||||
num_buckets = 32,
|
||||
max_distance = 128
|
||||
):
|
||||
n = -relative_position
|
||||
n = torch.max(n, torch.zeros_like(n))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = n < max_exact
|
||||
|
||||
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
|
||||
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
||||
return torch.where(is_small, n, val_if_large)
|
||||
|
||||
def forward(self, i, j, *, device):
|
||||
q_pos = torch.arange(i, dtype = torch.long, device = device)
|
||||
k_pos = torch.arange(j, dtype = torch.long, device = device)
|
||||
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
|
||||
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
|
||||
values = self.relative_attention_bias(rp_bucket)
|
||||
return rearrange(values, 'i j h -> h i j')
|
||||
|
||||
# feedforward
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
""" used successfully in https://arxiv.org/abs/2204.0231 """
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * F.silu(gate)
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -158,7 +246,7 @@ class Attention(nn.Module):
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
x = self.norm(x)
|
||||
@@ -175,6 +263,14 @@ class Attention(nn.Module):
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
# relative positional encoding (T5 style)
|
||||
|
||||
if exists(attn_bias):
|
||||
sim = sim + attn_bias
|
||||
|
||||
# masking
|
||||
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
@@ -187,8 +283,13 @@ class Attention(nn.Module):
|
||||
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
|
||||
sim = sim.masked_fill(causal_mask, max_neg_value)
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
@@ -209,7 +310,7 @@ class CausalTransformer(nn.Module):
|
||||
ff_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
# todo - bring in rotary embeddings or alibi
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
@@ -225,8 +326,12 @@ class CausalTransformer(nn.Module):
|
||||
x,
|
||||
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
|
||||
):
|
||||
n, device = x.shape[1], x.device
|
||||
|
||||
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x, mask = mask) + x
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
@@ -235,26 +340,26 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = 1000,
|
||||
num_timesteps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(*args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -274,8 +379,15 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
@@ -374,12 +486,13 @@ class DiffusionPrior(nn.Module):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return image_embed
|
||||
return l2norm(image_embed)
|
||||
|
||||
def get_text_cond(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
text_embed = l2norm(text_embed)
|
||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
@@ -512,6 +625,17 @@ def Upsample(dim):
|
||||
def Downsample(dim):
|
||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
filt = torch.Tensor([1, 2, 1])
|
||||
self.register_buffer('filt', filt)
|
||||
|
||||
def forward(self, x):
|
||||
filt = self.filt
|
||||
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
|
||||
return filter2d(x, filt, normalized = True)
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -539,10 +663,17 @@ class ConvNextBlock(nn.Module):
|
||||
super().__init__()
|
||||
need_projection = dim != dim_out
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Linear(cond_dim, dim)
|
||||
) if exists(cond_dim) else None
|
||||
self.cross_attn = None
|
||||
|
||||
if exists(cond_dim):
|
||||
self.cross_attn = EinopsToAndFrom(
|
||||
'b c h w',
|
||||
'b (h w) c',
|
||||
CrossAttention(
|
||||
dim = dim,
|
||||
context_dim = cond_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||
|
||||
@@ -559,21 +690,82 @@ class ConvNextBlock(nn.Module):
|
||||
def forward(self, x, cond = None):
|
||||
h = self.ds_conv(x)
|
||||
|
||||
if exists(self.mlp):
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
condition = self.mlp(cond)
|
||||
h = h + rearrange(condition, 'b c -> b c 1 1')
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
|
||||
h = self.net(h)
|
||||
|
||||
return h + self.res_conv(x)
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
context_dim = None,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm_context = RMSNorm(context_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, context, mask = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
|
||||
x = self.norm(x)
|
||||
context = self.norm_context(context)
|
||||
|
||||
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
|
||||
|
||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
|
||||
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (1, 0), value = True)
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
time_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
@@ -584,18 +776,31 @@ class Unet(nn.Module):
|
||||
dims = [channels, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
time_dim = default(time_dim, dim)
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, dim)
|
||||
nn.Linear(dim * 4, cond_dim),
|
||||
Rearrange('b d -> b 1 d')
|
||||
)
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
cond_dim = time_dim + image_embed_dim
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
self.ups = nn.ModuleList([])
|
||||
@@ -605,7 +810,7 @@ class Unet(nn.Module):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
@@ -613,7 +818,7 @@ class Unet(nn.Module):
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
@@ -633,16 +838,16 @@ class Unet(nn.Module):
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
*args,
|
||||
cond_scale = 1.,
|
||||
**kwargs
|
||||
):
|
||||
if cond_scale == 1:
|
||||
return self.forward(x, **kwargs)
|
||||
logits = self.forward(*args, **kwargs)
|
||||
|
||||
logits = self.forward(x, **kwargs)
|
||||
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
|
||||
if cond_scale == 1:
|
||||
return logits
|
||||
|
||||
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
|
||||
return null_logits + (logits - null_logits) * cond_scale
|
||||
|
||||
def forward(
|
||||
@@ -655,37 +860,59 @@ class Unet(nn.Module):
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
t = self.time_mlp(time)
|
||||
time_tokens = self.time_mlp(time)
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
|
||||
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
|
||||
image_embed = torch.where(
|
||||
rearrange(cond_prob_mask, 'b -> b 1'),
|
||||
image_embed,
|
||||
rearrange(self.null_image_embed, 'd -> 1 d')
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
|
||||
t = torch.cat((t, image_embed), dim = -1)
|
||||
# take care of text encodings (optional)
|
||||
|
||||
if exists(text_encodings):
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
|
||||
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
||||
|
||||
# text and image conditioning tokens (mid_c)
|
||||
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
||||
|
||||
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, convnext2, downsample in self.downs:
|
||||
x = convnext(x, t)
|
||||
x = convnext2(x, t)
|
||||
x = convnext(x, c)
|
||||
x = convnext2(x, c)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, t)
|
||||
x = self.mid_block1(x, mid_c)
|
||||
x = self.mid_attn(x)
|
||||
x = self.mid_block2(x, t)
|
||||
x = self.mid_block2(x, mid_c)
|
||||
|
||||
for convnext, convnext2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, t)
|
||||
x = convnext2(x, t)
|
||||
x = convnext(x, c)
|
||||
x = convnext2(x, c)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
@@ -746,11 +973,15 @@ class Decoder(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return image_embed
|
||||
return l2norm(image_embed)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
@@ -773,8 +1004,8 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
|
||||
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
|
||||
|
||||
if clip_denoised:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
@@ -783,31 +1014,32 @@ class Decoder(nn.Module):
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape, image_embed):
|
||||
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, image_embed):
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
@@ -817,7 +1049,7 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, x_start, t, *, image_embed, noise = None):
|
||||
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
@@ -826,6 +1058,7 @@ class Decoder(nn.Module):
|
||||
x_noisy,
|
||||
t,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
@@ -838,14 +1071,16 @@ class Decoder(nn.Module):
|
||||
|
||||
return loss
|
||||
|
||||
def forward(self, image):
|
||||
def forward(self, image, text = None):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
loss = self.p_losses(image, times, image_embed = image_embed)
|
||||
image_embed = self.get_image_embed(image)
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
|
||||
return loss
|
||||
|
||||
# main class
|
||||
@@ -861,14 +1096,16 @@ class DALLE2(nn.Module):
|
||||
super().__init__()
|
||||
assert isinstance(prior, DiffusionPrior)
|
||||
assert isinstance(decoder, Decoder)
|
||||
self.prior = prior.eval()
|
||||
self.decoder = decoder.eval()
|
||||
self.prior = prior
|
||||
self.decoder = decoder
|
||||
self.prior_num_samples = prior_num_samples
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def forward(
|
||||
self,
|
||||
text
|
||||
text,
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
|
||||
@@ -876,7 +1113,6 @@ class DALLE2(nn.Module):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
print(text.shape, type(text))
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
return images
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.4',
|
||||
version = '0.0.15',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -25,6 +25,7 @@ setup(
|
||||
'click',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
|
||||
Reference in New Issue
Block a user