Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)
Compare commits
10 Commits
23c401a5d5
68e9883f59
95b018374a
8b5c2385b0
f2c52d8239
97e951221b
e1b0c140f1
5989569a44
82464d7bd3
7fb3f695d5
README.md (30 changed lines)
@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch
```

## CLI Usage (work in progress)

```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```

Once built, images will be saved to the same directory the command is invoked

## Training (for deep learning practitioners)
## Usage

To train DALLE-2 is a 3 step process, with the training of CLIP being the most important

To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.

This repository will demonstrate integration with `x-clip` for starters
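As a rough, illustrative sketch of the `x-clip` integration mentioned above (every hyperparameter below is an example value, not something prescribed by this diff):

```python
# illustrative only: training a CLIP via the x-clip package referenced above
# dims, depths and the vocab size are example values
import torch
from x_clip import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 49408, (4, 256))      # mock token ids
images = torch.randn(4, 3, 256, 256)          # mock images

loss = clip(text, images, return_loss = True) # contrastive loss
loss.backward()
```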
@@ -109,7 +101,7 @@ clip = CLIP(
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()
@@ -162,7 +154,6 @@ clip = CLIP(

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
@@ -251,7 +242,6 @@ loss.backward()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    num_timesteps = 100,
    depth = 6,
    dim_head = 64,
    heads = 8
@@ -274,7 +264,7 @@ loss.backward()
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    time_dim = 128,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()
@@ -308,6 +298,18 @@ Everything in this readme should run without error

For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.

## CLI Usage (work in progress)

```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```

Once built, images will be saved to the same directory the command is invoked

## Training wrapper (wip)

Offer training wrappers

## Training CLI (wip)

<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
dalle2_pytorch/dalle2_pytorch.py

@@ -7,9 +7,12 @@ import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom

from kornia.filters import filter2d

from dalle2_pytorch.tokenizer import tokenizer

# use x-clip
@@ -115,25 +118,72 @@ class ChanRMSNorm(RMSNorm):
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale

class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs) + x
        return self.fn(x, **kwargs) + x

# mlp

class MLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)
        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        layers = [nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.SiLU(),
            norm_fn()
        )]

        for _ in range(depth - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                norm_fn()
            ))

        layers.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.float())
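The `MLP` above is the fallback the prior network uses further down for continuous timestep embeddings (`nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim))`). A minimal sketch of that pattern, assuming the `MLP` class above is in scope and using an illustrative `dim`:

```python
# illustrative: continuous timestep -> embedding via the MLP defined above
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512
to_time_embed = nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim))

timesteps = torch.rand(4)              # a batch of 4 continuous timesteps
time_embed = to_time_embed(timesteps)  # shape (4, 512)
```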
# feedforward

class SwiGLU(nn.Module):
    """ used successfully in https://arxiv.org/abs/2204.0231 """
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """

def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(mult * dim)
    return nn.Sequential(
        RMSNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(dim, inner_dim * 2, bias = False),
        SwiGLU(),
        RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )
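For intuition on the gated feedforward above: the inner projection goes up to `inner_dim * 2` and `SwiGLU` splits it into value and gate, halving it back to `inner_dim`. A quick shape check, assuming the new `FeedForward`, `SwiGLU`, and this file's `RMSNorm` are in scope (sizes are illustrative):

```python
# shape sanity check for the SwiGLU feedforward defined above
import torch

ff = FeedForward(dim = 512, mult = 4)  # inner_dim = 2048, projected to 4096, gated back to 2048
x = torch.randn(2, 10, 512)            # (batch, sequence, dim)
out = ff(x)                            # shape (2, 10, 512)
```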
# attention

class Attention(nn.Module):
    def __init__(
        self,
@@ -189,6 +239,7 @@ class Attention(nn.Module):

        sim = sim - sim.amax(dim = -1, keepdim = True)
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -235,27 +286,26 @@ class DiffusionPriorNetwork(nn.Module):
    def __init__(
        self,
        dim,
        num_timesteps = 1000,
        num_timesteps = None,
        **kwargs
    ):
        super().__init__()
        self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
        self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
        self.learned_query = nn.Parameter(torch.randn(dim))
        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

    def forward_with_cond_scale(
        self,
        x,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(x, *args, **kwargs)
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
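The `forward_with_cond_scale` change above is classifier-free guidance: one conditioned pass and one fully dropped pass (`cond_drop_prob = 1.`), blended as `null + (cond - null) * scale`. A standalone sketch of just the blending, with toy tensors:

```python
# classifier-free guidance blending, as in forward_with_cond_scale above
# toy tensors only; in the real code both terms come from the same network
import torch

cond_scale = 3.
logits = torch.randn(2, 1, 512)        # conditioned prediction
null_logits = torch.randn(2, 1, 512)   # prediction with conditioning dropped

guided = null_logits + (logits - null_logits) * cond_scale
# cond_scale = 1. recovers the conditioned prediction
assert torch.allclose(null_logits + (logits - null_logits) * 1., logits)
```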
@@ -275,8 +325,15 @@ class DiffusionPriorNetwork(nn.Module):

        text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')

        # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
        # but let's just do it right

        if exists(mask):
            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
            not_all_masked_out = mask.any(dim = -1)
            mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)

        if exists(mask):
            mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

        time_embed = self.time_embeddings(diffusion_timesteps)
        time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -514,6 +571,17 @@ def Upsample(dim):
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        filt = torch.Tensor([1, 2, 1])
        self.register_buffer('filt', filt)

    def forward(self, x):
        filt = self.filt
        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
        return filter2d(x, filt, normalized = True)

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
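The `Blur` module above builds its 2-d kernel as the outer product of the 1-d binomial filter `[1, 2, 1]` with itself before handing it to kornia's `filter2d`. The kernel it produces, for reference:

```python
# the kernel formed by Blur above: outer product of [1, 2, 1] with itself
import torch
from einops import rearrange

filt = torch.Tensor([1, 2, 1])
kernel = rearrange(filt, 'j -> 1 j') * rearrange(filt, 'i -> i 1')
# tensor([[1., 2., 1.],
#         [2., 4., 2.],
#         [1., 2., 1.]])
# filter2d(..., normalized = True) divides by the sum (16) before convolving
```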
@@ -541,10 +609,17 @@ class ConvNextBlock(nn.Module):
        super().__init__()
        need_projection = dim != dim_out

        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(cond_dim, dim)
        ) if exists(cond_dim) else None
        self.cross_attn = None

        if exists(cond_dim):
            self.cross_attn = EinopsToAndFrom(
                'b c h w',
                'b (h w) c',
                CrossAttention(
                    dim = dim,
                    context_dim = cond_dim
                )
            )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -561,21 +636,82 @@ class ConvNextBlock(nn.Module):
    def forward(self, x, cond = None):
        h = self.ds_conv(x)

        if exists(self.mlp):
        if exists(self.cross_attn):
            assert exists(cond)
            condition = self.mlp(cond)
            h = h + rearrange(condition, 'b c -> b c 1 1')
            h = self.cross_attn(h, context = cond) + h

        h = self.net(h)

        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = RMSNorm(dim)
        self.norm_context = RMSNorm(context_dim)
        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        sim = sim - sim.amax(dim = -1, keepdim = True)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        *,
        image_embed_dim,
        time_dim = None,
        cond_dim = None,
        num_image_tokens = 4,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
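The new `CrossAttention` above prepends a learned null key/value so that, under classifier-free guidance, even a fully masked-out context leaves the queries something to attend to. A shape check, assuming the class above (and this file's helpers) are in scope; all sizes are illustrative:

```python
# illustrative shape check for the CrossAttention defined above
import torch

attn = CrossAttention(dim = 128, context_dim = 512, dim_head = 64, heads = 8)

x = torch.randn(2, 64, 128)        # e.g. flattened feature-map tokens
context = torch.randn(2, 4, 512)   # e.g. conditioning tokens
mask = torch.zeros(2, 4).bool()    # everything masked: attention falls back to the null key/value

out = attn(x, context, mask = mask)  # shape (2, 64, 128)
```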
@@ -586,18 +722,28 @@ class Unet(nn.Module):
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = default(time_dim, dim)
        # time and image embeddings

        cond_dim = default(cond_dim, dim)

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
            nn.Linear(dim * 4, cond_dim),
            Rearrange('b d -> b 1 d')
        )

        self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
        self.image_to_cond = nn.Sequential(
            nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
            Rearrange('b (n d) -> b n d', n = num_image_tokens)
        ) if image_embed_dim != cond_dim else nn.Identity()

        cond_dim = time_dim + image_embed_dim
        # for classifier free guidance

        self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
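The new `image_to_cond` above turns a single image embedding into `num_image_tokens` conditioning tokens by projecting and reshaping. A minimal sketch of that reshaping with made-up dimensions:

```python
# illustrative: one image embedding -> num_image_tokens conditioning tokens
import torch
from torch import nn
from einops.layers.torch import Rearrange

image_embed_dim, cond_dim, num_image_tokens = 512, 128, 4

image_to_cond = nn.Sequential(
    nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
    Rearrange('b (n d) -> b n d', n = num_image_tokens)
)

image_embed = torch.randn(2, image_embed_dim)
image_tokens = image_to_cond(image_embed)   # shape (2, 4, 128)
```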
@@ -607,7 +753,7 @@ class Unet(nn.Module):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))
@@ -615,7 +761,7 @@ class Unet(nn.Module):
        mid_dim = dims[-1]

        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -635,17 +781,16 @@ class Unet(nn.Module):

    def forward_with_cond_scale(
        self,
        x,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(x, *args, **kwargs)
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
@@ -658,37 +803,39 @@ class Unet(nn.Module):
        cond_drop_prob = 0.
    ):
        batch_size, device = x.shape[0], x.device
        t = self.time_mlp(time)
        time_tokens = self.time_mlp(time)

        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)

        # mask out image embedding depending on condition dropout
        # for classifier free guidance

        image_embed = torch.where(
            rearrange(cond_prob_mask, 'b -> b 1'),
            image_embed,
            rearrange(self.null_image_embed, 'd -> 1 d')
        image_tokens = self.image_to_cond(image_embed)

        image_tokens = torch.where(
            rearrange(cond_prob_mask, 'b -> b 1 1'),
            image_tokens,
            self.null_image_embed
        )

        t = torch.cat((t, image_embed), dim = -1)
        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition

        hiddens = []

        for convnext, convnext2, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            x = convnext(x, c)
            x = convnext2(x, c)
            hiddens.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block1(x, c)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        x = self.mid_block2(x, c)

        for convnext, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = convnext(x, c)
            x = convnext2(x, c)
            x = upsample(x)

        return self.final_conv(x)
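In the reworked `forward` above, conditioning is dropped per sample (for classifier-free guidance) by swapping a sample's image tokens for the learned null embedding. A minimal sketch of that swap; the keep-mask here is a stand-in for the repository's `prob_mask_like` helper:

```python
# illustrative per-sample condition dropout, as in Unet.forward above
import torch
from einops import rearrange

batch_size, num_image_tokens, cond_dim = 2, 4, 128
cond_drop_prob = 0.2

image_tokens = torch.randn(batch_size, num_image_tokens, cond_dim)
null_image_embed = torch.randn(1, num_image_tokens, cond_dim)  # a learned nn.Parameter in the real model

keep_mask = torch.rand(batch_size) > cond_drop_prob            # True = keep this sample's conditioning

image_tokens = torch.where(
    rearrange(keep_mask, 'b -> b 1 1'),
    image_tokens,
    null_image_embed                                           # broadcasts over the batch
)
```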
@@ -864,11 +1011,12 @@ class DALLE2(nn.Module):
        super().__init__()
        assert isinstance(prior, DiffusionPrior)
        assert isinstance(decoder, Decoder)
        self.prior = prior.eval()
        self.decoder = decoder.eval()
        self.prior = prior
        self.decoder = decoder
        self.prior_num_samples = prior_num_samples

    @torch.no_grad()
    @eval_decorator
    def forward(
        self,
        text,
setup.py (3 changed lines)
@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.7',
  version = '0.0.12',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
    'click',
    'einops>=0.4',
    'einops-exts>=0.0.3',
    'kornia>=0.5.4',
    'pillow',
    'torch>=1.10',
    'torchvision',