Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-23 16:24:21 +01:00
Compare commits
4 Commits

- c400d8758c
- bece206699
- 5b4ee09625
- 6e27f617f1
README.md

@@ -325,6 +325,7 @@ Offer training wrappers
 - [ ] train on a toy task, offer in colab
 - [ ] add attention to unet - apply some personal tricks with efficient attention
 - [ ] figure out the big idea behind latent diffusion and what can be ported over
+- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
 
 ## Citations
 
dalle2_pytorch/dalle2_pytorch.py

@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
-from kornia.filters import filter2d
+from kornia.filters.gaussian import GaussianBlur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
 
@@ -161,6 +161,44 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.net(x.float())
 
+# relative positional bias for causal transformer
+
+class RelPosBias(nn.Module):
+    def __init__(
+        self,
+        heads = 8,
+        num_buckets = 32,
+        max_distance = 128,
+    ):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        n = -relative_position
+        n = torch.max(n, torch.zeros_like(n))
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
+        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+        return torch.where(is_small, n, val_if_large)
+
+    def forward(self, i, j, *, device):
+        q_pos = torch.arange(i, dtype = torch.long, device = device)
+        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')
+
 # feedforward
 
 class SwiGLU(nn.Module):
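
To make the bucketing concrete, here is a standalone sketch (not part of the diff) applying the same `_relative_position_bucket` formula with the default `num_buckets = 32`, `max_distance = 128`: offsets into the past get exact buckets up to `num_buckets // 2`, then logarithmically coarser ones, capped at the last bucket.

```python
import math
import torch

def relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
    # same computation as RelPosBias._relative_position_bucket above
    n = torch.max(-relative_position, torch.zeros_like(relative_position))

    max_exact = num_buckets // 2
    is_small = n < max_exact

    val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    return torch.where(is_small, n, val_if_large)

offsets = torch.tensor([0, -1, -2, -15, -16, -64, -500])  # key position minus query position
print(relative_position_bucket(offsets))
# tensor([ 0,  1,  2, 15, 16, 26, 31]) -- exact up to 15, log-spaced after, capped at 31
```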
@@ -208,7 +246,7 @@ class Attention(nn.Module):
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device
 
         x = self.norm(x)
@@ -225,6 +263,14 @@
         q = q * self.scale
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
 
+        # relative positional encoding (T5 style)
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
+        # masking
+
         max_neg_value = -torch.finfo(sim.dtype).max
 
         if exists(mask):
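
Note the bias added to `sim` is head-specific but batch-agnostic: `RelPosBias.forward` returns shape `(heads, i, j)` while `sim` is `(b, heads, i, j)`, so the addition broadcasts over the batch dimension. A minimal shape sketch:

```python
import torch

b, h, i, j = 2, 8, 4, 5
sim = torch.randn(b, h, i, j)       # attention logits
attn_bias = torch.randn(h, i, j)    # one bias matrix per head, shared across the batch

assert (sim + attn_bias).shape == (b, h, i, j)  # broadcasts over batch
```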
@@ -237,10 +283,14 @@
             causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)
 
+        # attention
+
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
+        # aggregate values
+
         out = einsum('b h i j, b j d -> b h i d', attn, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
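
The `sim - sim.amax(...)` context line is the usual max-subtraction guard: softmax is invariant to subtracting a per-row constant, and it keeps `exp` from overflowing on large logits. A quick illustration:

```python
import torch

logits = torch.tensor([1000., 1001., 1002.])

naive = logits.exp() / logits.exp().sum()   # exp overflows to inf, yielding nans
stable = (logits - logits.amax()).exp()
stable = stable / stable.sum()              # tensor([0.0900, 0.2447, 0.6652])
```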
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
         ff_dropout = 0.
     ):
         super().__init__()
-        # todo - bring in rotary embeddings or alibi
+        self.rel_pos_bias = RelPosBias(heads = heads)
 
         self.layers = nn.ModuleList([])
         for _ in range(depth):
@@ -276,8 +326,12 @@
         x,
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
+        n, device = x.shape[1], x.device
+
+        attn_bias = self.rel_pos_bias(n, n + 1, device = device)
+
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         return self.norm(x)
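
The bias is built for one more key column than query rows (`n + 1` versus `n`) because the `Attention` block in this file prepends a learned null key/value pair to the keys and values (that code is outside these hunks). A shape check, assuming the `RelPosBias` added above:

```python
import torch

rel_pos_bias = RelPosBias(heads = 8)
bias = rel_pos_bias(16, 16 + 1, device = torch.device('cpu'))
print(bias.shape)  # torch.Size([8, 16, 17]) -- 17 key positions for 16 queries
```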
@@ -396,7 +450,7 @@ class DiffusionPrior(nn.Module):
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
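
The `(1, 0)` fix pads on the left rather than the right, which is what `alphas_cumprod_prev` is supposed to hold: the cumulative product shifted back one step, with a value of 1 at the first timestep. A small check with made-up values:

```python
import torch
import torch.nn.functional as F

alphas_cumprod = torch.tensor([0.9, 0.8, 0.7, 0.6])     # toy cumulative products

before = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.) # tensor([0.9000, 0.8000, 0.7000, 1.0000]) -- 1 lands at the *end*
after  = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) # tensor([1.0000, 0.9000, 0.8000, 0.7000]) -- shifted right, 1 at t = 0

assert torch.equal(after[1:], alphas_cumprod[:-1])      # after[t] == alphas_cumprod[t - 1]
```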
@@ -571,17 +625,6 @@ def Upsample(dim):
 def Downsample(dim):
     return nn.Conv2d(dim, dim, 4, 2, 1)
 
-class Blur(nn.Module):
-    def __init__(self):
-        super().__init__()
-        filt = torch.Tensor([1, 2, 1])
-        self.register_buffer('filt', filt)
-
-    def forward(self, x):
-        filt = self.filt
-        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
-        return filter2d(x, filt, normalized = True)
-
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
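
For reference, the deleted `Blur` built a 3×3 binomial kernel (outer product of `[1, 2, 1]` with itself), a cheap Gaussian approximation; the hunks below swap in kornia's `GaussianBlur2d` instead. A sketch of the kernel it used:

```python
import torch
from einops import rearrange

filt = torch.Tensor([1, 2, 1])
kernel = rearrange(filt, 'j -> 1 j') * rearrange(filt, 'i -> i 1')
print(kernel / kernel.sum())
# tensor([[0.0625, 0.1250, 0.0625],
#         [0.1250, 0.2500, 0.1250],
#         [0.0625, 0.1250, 0.0625]])
```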
@@ -715,11 +758,25 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
+        lowres_cond_upsample_mode = 'bilinear',
+        blur_sigma = 0.1
     ):
         super().__init__()
 
+        # for eventual cascading diffusion
+
+        self.lowres_cond = lowres_cond
+        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
+        self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
+
+        # determine dimensions
+
         self.channels = channels
 
-        dims = [channels, *map(lambda m: dim * m, dim_mults)]
+        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+
+        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         # time, image embeddings, and optional text encoding
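
Because the blurred low-resolution image is concatenated along the channel dimension, the first convolution has to accept twice the image channels when `lowres_cond = True`. Tracing the new `init_channels` logic with illustrative values:

```python
channels, dim, lowres_cond = 3, 64, True
dim_mults = (1, 2, 4, 8)

init_channels = channels if not lowres_cond else channels * 2  # 6
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]     # [6, 64, 128, 256, 512]
in_out = list(zip(dims[:-1], dims[1:]))                        # [(6, 64), (64, 128), (128, 256), (256, 512)]
```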
@@ -802,12 +859,30 @@
         time,
         *,
         image_embed,
+        lowres_cond_img = None,
         text_encodings = None,
         cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
 
+        # add low resolution conditioning, if present
+
+        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
+
+        if exists(lowres_cond_img):
+            if self.training:
+                # when training, blur the low resolution conditional image
+                lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
+
+            lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
+            x = torch.cat((x, lowres_cond_img), dim = 1)
+
+        # time conditioning
+
         time_tokens = self.time_mlp(time)
 
+        # conditional dropout
+
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
         cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
 
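
The new conditioning path in isolation: blur (training only), resize to the current stage's resolution, then concatenate along channels. A standalone sketch with illustrative sizes, assuming kornia is installed:

```python
import torch
import torch.nn.functional as F
from kornia.filters.gaussian import GaussianBlur2d

blur = GaussianBlur2d((3, 3), (0.1, 0.1))     # mirrors self.lowres_cond_blur with blur_sigma = 0.1

x = torch.randn(1, 3, 256, 256)               # noisy input at the current stage's resolution
lowres_cond_img = torch.randn(1, 3, 64, 64)   # output of the previous cascade stage

lowres_cond_img = blur(lowres_cond_img)       # train-time blur augmentation
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = 'bilinear')
x = torch.cat((x, lowres_cond_img), dim = 1)

print(x.shape)  # torch.Size([1, 6, 256, 256]) -- matches init_channels = channels * 2
```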
@@ -887,7 +962,7 @@ class Decoder(nn.Module):
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)