Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 19:44:26 +01:00)

Compare commits (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a95de92964 | |
| | 5b4ee09625 | |
| | 6e27f617f1 | |
| | 9f55c24db6 | |
| | 69e822b7f8 | |

The combined diff for these commits follows.
```diff
@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()

-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()

 # do above for many steps
```
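Training with a nonzero `cond_drop_prob`, as above, is what later enables classifier-free guidance at sampling time through the `cond_scale` argument (see the `Decoder.sample` change further down in this diff). A usage sketch, assuming a trained `decoder`, CLIP image embeddings `image_embed`, and tokenized `text` already exist:

```python
# Sketch: sampling with classifier-free guidance after training with cond_drop_prob > 0.
# `decoder`, `image_embed` and `text` are assumed to already exist;
# the signature follows the Decoder.sample change later in this diff.
images = decoder.sample(image_embed, text = text, cond_scale = 3.)  # cond_scale > 1. strengthens the conditioning
```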
```diff
@@ -319,12 +319,13 @@ Offer training wrappers
 - [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
 - [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
 - [x] make sure it works end to end to produce an output tensor, taking a single gradient step
-- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
+- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
 - [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
 - [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [ ] train on a toy task, offer in colab
 - [ ] add attention to unet - apply some personal tricks with efficient attention
 - [ ] figure out the big idea behind latent diffusion and what can be ported over
 - [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007

 ## Citations
```
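Two checked items above concern the DDPM parameterization: the network can predict the noise ε or the clean sample x0 directly, and the two targets are interchangeable in closed form. A minimal sketch of that relation, with illustrative helper names that are not from this repo:

```python
# Sketch: the two DDPM prediction targets mentioned in the checklist above.
# Given x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise, a network may
# predict the noise (epsilon objective) or x_0 directly; the two are
# related by the closed forms below. Illustrative only.
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

def predict_x0_from_noise(x_t, t, noise):
    # invert the forward process: x_0 = (x_t - sqrt(1 - a_bar_t) * noise) / sqrt(a_bar_t)
    a_bar = alphas_cumprod[t]
    return (x_t - (1 - a_bar).sqrt() * noise) / a_bar.sqrt()

def predict_noise_from_x0(x_t, t, x0):
    # the reverse mapping, used when the network predicts x_0 directly
    a_bar = alphas_cumprod[t]
    return (x_t - a_bar.sqrt() * x0) / (1 - a_bar).sqrt()

x0, noise, t = torch.randn(4), torch.randn(4), 500
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * noise
assert torch.allclose(predict_x0_from_noise(x_t, t, noise), x0, atol = 1e-5)
```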
```diff
@@ -161,6 +161,44 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.net(x.float())

+# relative positional bias for causal transformer
+
+class RelPosBias(nn.Module):
+    def __init__(
+        self,
+        heads = 8,
+        num_buckets = 32,
+        max_distance = 128,
+    ):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        n = -relative_position
+        n = torch.max(n, torch.zeros_like(n))
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
+        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+        return torch.where(is_small, n, val_if_large)
+
+    def forward(self, i, j, *, device):
+        q_pos = torch.arange(i, dtype = torch.long, device = device)
+        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')
+
 # feedforward

 class SwiGLU(nn.Module):
```
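A quick shape check on the class added above, assuming `RelPosBias` (and its `torch`/`einops` dependencies) are in scope. Each head gets one learned scalar per (query, key) offset, with exact buckets for small distances and log-spaced buckets out to `max_distance`, T5 style:

```python
# Sketch: the bias returned for 4 queries attending over 5 keys.
import torch

rel_pos_bias = RelPosBias(heads = 8, num_buckets = 32, max_distance = 128)
bias = rel_pos_bias(4, 5, device = torch.device('cpu'))

print(bias.shape)  # torch.Size([8, 4, 5]) - broadcasts against (batch, heads, i, j) attention logits
```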
```diff
@@ -208,7 +246,7 @@ class Attention(nn.Module):
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

-    def forward(self, x, mask = None):
+    def forward(self, x, mask = None, attn_bias = None):
         b, n, device = *x.shape[:2], x.device

         x = self.norm(x)
@@ -225,6 +263,14 @@ class Attention(nn.Module):
         q = q * self.scale

         sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+        # relative positional encoding (T5 style)
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
+        # masking
+
         max_neg_value = -torch.finfo(sim.dtype).max

         if exists(mask):
@@ -237,10 +283,14 @@ class Attention(nn.Module):
             causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)

+        # attention
+
         sim = sim - sim.amax(dim = -1, keepdim = True)
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)

+        # aggregate values
+
         out = einsum('b h i j, b j d -> b h i d', attn, v)

         out = rearrange(out, 'b h n d -> b n (h d)')
```
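Two mechanics here are worth seeing in isolation: a per-head bias of shape `(h, i, j)` broadcasting over the batch dimension of the `(b, h, i, j)` logits, and the `.triu(j - i + 1)` causal mask, which also handles more keys than queries. The `CausalTransformer` below calls `rel_pos_bias(n, n + 1)`, one extra key slot, presumably leaving room for something like a null key/value; that is a guess, as the diff does not show it. A standalone demo:

```python
# Sketch: how the attn_bias and causal mask in the hunk above behave.
import torch

b, h, i, j = 2, 8, 3, 4
sim = torch.randn(b, h, i, j)        # attention logits
attn_bias = torch.randn(h, i, j)     # per-head relative position bias

sim = sim + attn_bias                # broadcasts over the batch dimension

# causal mask: True above the shifted diagonal marks disallowed (future) keys
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
print(causal_mask)
# tensor([[False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
attn = (sim - sim.amax(dim = -1, keepdim = True)).softmax(dim = -1)
```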
```diff
@@ -260,7 +310,7 @@ class CausalTransformer(nn.Module):
         ff_dropout = 0.
     ):
         super().__init__()
-        # todo - bring in rotary embeddings or alibi
+        self.rel_pos_bias = RelPosBias(heads = heads)

         self.layers = nn.ModuleList([])
         for _ in range(depth):
@@ -276,8 +326,12 @@ class CausalTransformer(nn.Module):
         x,
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
+        n, device = x.shape[1], x.device
+
+        attn_bias = self.rel_pos_bias(n, n + 1, device = device)
+
         for attn, ff in self.layers:
-            x = attn(x, mask = mask) + x
+            x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x

         return self.norm(x)
```
```diff
@@ -396,7 +450,7 @@ class DiffusionPrior(nn.Module):

         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
```
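This one-character change is a bug fix: `F.pad` with `(0, 1)` appends the 1 at the end, whereas `alphas_cumprod_prev` needs the 1 at the front so that `alphas_cumprod_prev[t] == alphas_cumprod[t - 1]`, with the `t == 0` entry equal to 1 (no accumulated noise before the first step). A quick demonstration:

```python
# Sketch: why the padding side matters for alphas_cumprod_prev.
import torch
import torch.nn.functional as F

alphas_cumprod = torch.tensor([0.9, 0.8, 0.7, 0.6])

wrong = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)  # tensor([0.9000, 0.8000, 0.7000, 1.0000])
fixed = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)  # tensor([1.0000, 0.9000, 0.8000, 0.7000])

# fixed[t] == alphas_cumprod[t - 1], as the posterior coefficients
# computed from it below require
assert torch.equal(fixed[1:], alphas_cumprod[:-1])
```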
```diff
@@ -722,7 +776,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))

-        # time and image embeddings
+        # time, image embeddings, and optional text encoding

         cond_dim = default(cond_dim, dim)

@@ -739,9 +793,12 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()

+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance

         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))

         # layers

@@ -806,6 +863,7 @@ class Unet(nn.Module):
         time_tokens = self.time_mlp(time)

         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')

         # mask out image embedding depending on condition dropout
         # for classifier free guidance
```
```diff
@@ -813,12 +871,31 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)

         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )

-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)

         # go through the layers of the unet, down and up

         hiddens = []
```
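The `torch.where` pattern above is the conditioning dropout behind classifier-free guidance: per sample, the real conditioning tokens are either kept or swapped for learned null embeddings, so one network learns both a conditional and an unconditional model. A standalone sketch of the mechanic; note that whether `prob_mask_like` returns a keep-mask or a drop-mask depends on its definition, which this diff does not show, so the sketch makes the keep semantics explicit:

```python
# Sketch: conditioning dropout via torch.where, as in the hunk above.
import torch
from einops import rearrange

batch, num_tokens, cond_dim = 4, 5, 512
keep_prob = 1. - 0.2  # cond_drop_prob = 0.2

cond_tokens = torch.randn(batch, num_tokens, cond_dim)
null_tokens = torch.randn(1, num_tokens, cond_dim)  # a learned nn.Parameter in the Unet

# True -> keep the real conditioning, False -> substitute the null embedding
keep_mask = torch.zeros(batch).float().uniform_(0, 1) < keep_prob
keep_mask = rearrange(keep_mask, 'b -> b 1 1')

cond_tokens = torch.where(keep_mask, cond_tokens, null_tokens)  # broadcasts over tokens and batch
```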
```diff
@@ -828,9 +905,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)

-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)

         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
```
```diff
@@ -864,7 +941,7 @@ class Decoder(nn.Module):

         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
```
```diff
@@ -896,6 +973,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
```
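Note the complementary slicing: `get_text_encodings` keeps every token except the first (`[:, 1:]`, presumably dropping a BOS/CLS-style summary token), while `get_image_embed` keeps only the first token (`[:, 0]`) as the image summary. A shape-only sketch with dummy tensors standing in for the CLIP transformers, which are defined elsewhere in the repo:

```python
# Sketch: the slicing conventions in get_text_encodings / get_image_embed above.
import torch

text_out  = torch.randn(2, 257, 512)  # (batch, 1 + text_len, dim) from a text transformer
image_out = torch.randn(2, 257, 512)  # (batch, 1 + num_patches, dim) from a vision transformer

text_encodings = text_out[:, 1:]   # per-token encodings, first (summary) token dropped
image_cls      = image_out[:, 0]   # single summary token kept as the image embedding

print(text_encodings.shape, image_cls.shape)  # torch.Size([2, 256, 512]) torch.Size([2, 512])
```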
```diff
@@ -923,8 +1004,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))

         if clip_denoised:
             x_recon.clamp_(-1., 1.)
```
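`forward_with_cond_scale` itself is not part of this diff. Conventionally, a method with this name evaluates the network once with conditioning kept and once with conditioning dropped to the null embeddings, then extrapolates by `cond_scale`. The sketch below shows that convention and is an assumption, not the definition used here; it only relies on the `cond_drop_prob` keyword that this diff does pass to the Unet:

```python
# Sketch: the classifier-free guidance combination a forward_with_cond_scale
# method conventionally implements (assumed; its definition is not in this diff).
def forward_with_cond_scale(net, x, t, *, cond_scale = 1., **cond_kwargs):
    cond_out = net(x, t, cond_drop_prob = 0., **cond_kwargs)     # conditioning kept

    if cond_scale == 1:
        return cond_out

    null_out = net(x, t, cond_drop_prob = 1., **cond_kwargs)     # conditioning dropped to null embeds
    return null_out + (cond_out - null_out) * cond_scale         # extrapolate beyond the conditional
```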
```diff
@@ -933,31 +1014,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device=device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img

     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
```
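`p_sample` draws `x_{t-1} ~ N(model_mean, exp(model_log_variance))`, and `nonzero_mask` switches the noise off wherever `t == 0`, so the final denoising step is deterministic. A toy demonstration of that gating; note the repo's `reshape(b, *((1,) * (len(x.shape) - 1)))` is just `reshape(b, 1, 1, 1)` for image tensors:

```python
# Sketch: the t == 0 noise gating used in p_sample above.
import torch

b = 3
t = torch.tensor([5, 1, 0])                 # per-sample timesteps
nonzero_mask = (1 - (t == 0).float()).reshape(b, 1, 1, 1)
print(nonzero_mask.flatten())               # tensor([1., 1., 0.]) -> last sample gets no noise

model_mean = torch.zeros(b, 3, 8, 8)
model_log_variance = torch.zeros(b, 3, 8, 8)
noise = torch.randn(b, 3, 8, 8)

x_prev = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
assert torch.equal(x_prev[2], model_mean[2])  # deterministic at t == 0
```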
```diff
@@ -967,7 +1049,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
```
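`q_sample` is the closed-form forward process `x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε`, which lets `p_losses` jump to an arbitrary timestep in one shot rather than applying t noising steps. A quick check that the one-shot coefficients match the composed single steps:

```python
# Sketch: q_sample's one-shot noising coefficients agree with composing
# single steps x_s = sqrt(alpha_s) x_{s-1} + sqrt(beta_s) eps_s.
import torch

betas = torch.linspace(1e-4, 0.02, 100)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)

t = 42

# one-shot: x_t = sqrt(a_bar_t) x_0 + sqrt(1 - a_bar_t) eps
scale = alphas_cumprod[t].sqrt()
sigma = (1 - alphas_cumprod[t]).sqrt()

# iterating the single steps: the mean scale is prod(sqrt(alpha_s)),
# and the noise variance telescopes to 1 - a_bar_t
iter_scale = alphas[:t + 1].sqrt().prod()
iter_var = 1 - alphas[:t + 1].prod()

assert torch.allclose(scale, iter_scale)
assert torch.allclose(sigma ** 2, iter_var, atol = 1e-6)
```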
```diff
@@ -976,6 +1058,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )

```
```diff
@@ -988,14 +1071,16 @@ class Decoder(nn.Module):

         return loss

-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-        image_embed = self.get_image_embed(image)

-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss

 # main class
```