Compare commits

...

4 Commits

3 changed files with 95 additions and 19 deletions

View File

@@ -325,6 +325,7 @@ Offer training wrappers
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
## Citations

View File

@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from kornia.filters.gaussian import GaussianBlur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -161,6 +161,44 @@ class MLP(nn.Module):
def forward(self, x):
return self.net(x.float())
# relative positional bias for causal transformer
class RelPosBias(nn.Module):
def __init__(
self,
heads = 8,
num_buckets = 32,
max_distance = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
n = -relative_position
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
return torch.where(is_small, n, val_if_large)
def forward(self, i, j, *, device):
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')
# feedforward
class SwiGLU(nn.Module):
@@ -208,7 +246,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
@@ -225,6 +263,14 @@ class Attention(nn.Module):
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k)
# relative positional encoding (T5 style)
if exists(attn_bias):
sim = sim + attn_bias
# masking
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
@@ -237,10 +283,14 @@ class Attention(nn.Module):
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
# attention
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
ff_dropout = 0.
):
super().__init__()
# todo - bring in rotary embeddings or alibi
self.rel_pos_bias = RelPosBias(heads = heads)
self.layers = nn.ModuleList([])
for _ in range(depth):
@@ -276,8 +326,12 @@ class CausalTransformer(nn.Module):
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
n, device = x.shape[1], x.device
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
return self.norm(x)
@@ -396,7 +450,7 @@ class DiffusionPrior(nn.Module):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -571,17 +625,6 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -715,11 +758,25 @@ class Unet(nn.Module):
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1
):
super().__init__()
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
# determine dimensions
self.channels = channels
dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
@@ -802,12 +859,30 @@ class Unet(nn.Module):
time,
*,
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
# add low resolution conditioning, if present
assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
time_tokens = self.time_mlp(time)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
@@ -887,7 +962,7 @@ class Decoder(nn.Module):
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.14',
version = '0.0.17',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',