Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-20 02:04:19 +01:00
Commit: fix everything and make sure it runs end to end, document everything in readme for public
dalle2_pytorch/__init__.py:

```diff
@@ -1 +1,2 @@
-from dalle2_pytorch.dalle2_pytorch import DALLE2
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+from x_clip import CLIP
```
dalle2_pytorch/dalle2_pytorch.py:

```diff
@@ -1,10 +1,16 @@
-import tqdm
+import math
+from tqdm import tqdm
+from inspect import isfunction

 import torch
 import torch.nn.functional as F
 from torch import nn, einsum

-from einops import rearrange
-from einops_exts import rearrange_many, repeat_many
+from einops import rearrange, repeat
+from einops_exts import rearrange_many, repeat_many, check_shape
+from einops_exts.torch import EinopsToAndFrom
+
+from dalle2_pytorch.tokenizer import tokenizer

 # use x-clip
@@ -16,7 +22,9 @@ def exists(val):
     return val is not None

 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d

 def eval_decorator(fn):
     def inner(model, *args, **kwargs):
```
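
The rewritten `default` accepts a callable fallback, so an expensive default (such as sampling a noise tensor) is only evaluated when the value is actually missing. A minimal sketch of the behavior:

```python
import torch
from inspect import isfunction

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# the lambda only runs when the value is missing, so the eager allocation is skipped
noise = default(None, lambda: torch.randn(4))             # lambda is called
kept  = default(torch.zeros(4), lambda: torch.randn(4))   # lambda never runs
print(noise.shape, kept.sum().item())  # torch.Size([4]) 0.0
```

This is why later hunks can write `default(noise, lambda: torch.randn_like(image_embed))` without paying for the `randn_like` on every call.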
```diff
@@ -27,6 +35,11 @@ def eval_decorator(fn):
         return out
     return inner

+def is_list_str(x):
+    if not isinstance(x, (list, tuple)):
+        return False
+    return all([type(el) == str for el in x])
+
 # for controlling freezing of CLIP

 def set_module_requires_grad_(module, requires_grad):
@@ -43,6 +56,11 @@ def freeze_model_and_make_eval_(model):
     model.eval()
     freeze_all_layers_(model)

+# tensor helpers
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
 # classifier free guidance functions

 def prob_mask_like(shape, prob, device):
@@ -91,9 +109,16 @@ class RMSNorm(nn.Module):
         inv_norm = torch.rsqrt(squared_sum + self.eps)
         return x * inv_norm * self.gamma * self.scale

+class ChanRMSNorm(RMSNorm):
+    def forward(self, x):
+        squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
+        inv_norm = torch.rsqrt(squared_sum + self.eps)
+        return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
+
 class PreNormResidual(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
         self.fn = fn
         self.norm = RMSNorm(dim)

     def forward(self, x, **kwargs):
```
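
The new `ChanRMSNorm` subclasses `RMSNorm` but normalizes over the channel axis of `b c h w` feature maps, broadcasting `gamma` from `c` to `1 c 1 1`; a later hunk wires it into `ConvNextBlock` in place of the last-dim `RMSNorm`. A self-contained sketch, assuming `RMSNorm` stores `eps`, a learned per-dimension `gamma`, and a `dim ** 0.5` scale (its constructor is not shown in this hunk):

```python
import torch
from torch import nn
from einops import rearrange

class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.scale = dim ** 0.5                     # assumed fixed scale
        self.gamma = nn.Parameter(torch.ones(dim))  # assumed learned gain

    def forward(self, x):
        squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * self.gamma * self.scale

class ChanRMSNorm(RMSNorm):
    def forward(self, x):
        # normalize over channels (dim 1) instead of the last dim
        squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
        inv_norm = torch.rsqrt(squared_sum + self.eps)
        return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale

feats = torch.randn(2, 64, 16, 16)
print(ChanRMSNorm(64)(feats).shape)  # torch.Size([2, 64, 16, 16])
```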
```diff
@@ -112,8 +137,8 @@ def FeedForward(dim, mult = 4, dropout = 0.):
 class Attention(nn.Module):
     def __init__(
         self,
-        *,
         dim,
+        *,
         dim_head = 64,
         heads = 8,
         dropout = 0.,
@@ -121,6 +146,7 @@ class Attention(nn.Module):
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
+        inner_dim = dim_head * heads

         self.causal = causal
@@ -128,17 +154,17 @@ class Attention(nn.Module):
         self.dropout = nn.Dropout(dropout)

         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
-        self.to_qkv = nn.Linear(dim, inner_dim, bias = False)
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

     def forward(self, x, mask = None):
-        b, n, device = x.shape[:2], x.device
+        b, n, device = *x.shape[:2], x.device

         x = self.norm(x)
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

-        q = rearrange(q, 'b n (h d) -> b h n d')
+        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

         # add null key / value for classifier free guidance in prior net
```
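
After this hunk, queries are projected per head while keys and values are projected once at `dim_head` width and shared across all query heads, in the spirit of one-headed key/values (multi-query attention); note the missing head axis on `k` in the einsum fixed in the next hunk. A toy sketch of the resulting shapes:

```python
import torch
from torch import einsum
from einops import rearrange

b, n, h, d = 2, 8, 4, 64
q = torch.randn(b, n, h * d)  # per-head queries
k = torch.randn(b, n, d)      # single shared key head
v = torch.randn(b, n, d)      # single shared value head

q = rearrange(q, 'b n (h d) -> b h n d', h = h)
sim = einsum('b h i d, b j d -> b h i j', q, k)    # k broadcasts across heads
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b j d -> b h i d', attn, v)
print(out.shape)  # torch.Size([2, 4, 8, 64])
```

Presumably this trades the larger shared-qkv projection for a smaller key/value one, rather than being a pure bug fix.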
```diff
@@ -148,7 +174,7 @@ class Attention(nn.Module):
         q = q * self.scale

-        sim = einsum('b h i d, b j d -> b h i j')
+        sim = einsum('b h i d, b j d -> b h i j', q, k)
         max_neg_value = -torch.finfo(sim.dtype).max

         if exists(mask):
@@ -157,7 +183,8 @@ class Attention(nn.Module):
             sim = sim.masked_fill(~mask, max_neg_value)

         if self.causal:
-            causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+            i, j = sim.shape[-2:]
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(causal_mask, max_neg_value)

         sim = sim - sim.amax(dim = -1, keepdim = True)
```
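
The causal mask is now derived from `sim`'s actual dimensions rather than assuming a square `n × n` grid: with the null key/value prepended, the key axis is one longer than the query axis (`j = i + 1`), and the `triu(j - i + 1)` offset lets each query see the null token and every position up to its own, but nothing later. For example:

```python
import torch

i, j = 4, 5  # 4 queries, 5 keys (one prepended null key)
mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)
print(mask.int())
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
```

`True` entries are the positions filled with `max_neg_value` before the softmax.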
```diff
@@ -214,7 +241,7 @@ class DiffusionPriorNetwork(nn.Module):
         super().__init__()
         self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
         self.learned_query = nn.Parameter(torch.randn(dim))
-        self.causal_transformer = CausalTransformer(**kwargs)
+        self.causal_transformer = CausalTransformer(dim = dim, **kwargs)

     def forward_with_cond_scale(
         self,
@@ -227,7 +254,7 @@ class DiffusionPriorNetwork(nn.Module):
             return self.forward(x, **kwargs)

         logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
```
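
`forward_with_cond_scale` is classifier-free guidance (note the keyword rename to `cond_drop_prob`, applied consistently through the rest of the commit): one conditional pass, one pass with conditioning fully dropped, then an extrapolation away from the unconditional prediction. A tiny numeric sketch:

```python
import torch

def guided(logits, null_logits, cond_scale):
    # null + scale * (cond - null): scale 1 recovers the conditional output,
    # scale > 1 pushes further in the direction the condition suggests
    return null_logits + (logits - null_logits) * cond_scale

logits, null_logits = torch.tensor([2.0]), torch.tensor([1.0])
print(guided(logits, null_logits, 1.0))  # tensor([2.])
print(guided(logits, null_logits, 3.0))  # tensor([4.])
```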
```diff
@@ -248,9 +275,10 @@ class DiffusionPriorNetwork(nn.Module):
         text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')

         if exists(mask):
-            mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+            mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

         time_embed = self.time_embeddings(diffusion_timesteps)
+        time_embed = rearrange(time_embed, 'b d -> b 1 d')

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -268,7 +296,7 @@ class DiffusionPriorNetwork(nn.Module):

         # classifier free guidance

-        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
         mask &= rearrange(cond_prob_mask, 'b -> b 1')

         # attend
@@ -288,19 +316,20 @@ class DiffusionPrior(nn.Module):
         *,
         clip,
         timesteps = 1000,
-        cond_prob_drop = 0.2,
+        cond_drop_prob = 0.2,
         loss_type = 'l1',
         predict_x0 = True
     ):
         super().__init__()
+        assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip

         self.net = net
         self.image_embed_dim = clip.dim_latent
         self.channels = clip.image_channels
         self.image_size = clip.image_size
-        self.cond_prob_drop = cond_prob_drop
+        self.cond_drop_prob = cond_drop_prob

         self.predict_x0 = predict_x0
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -389,7 +418,7 @@ class DiffusionPrior(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, text_cond = None, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
         model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
@@ -420,18 +449,18 @@ class DiffusionPrior(nn.Module):
         text_cond = self.get_text_cond(text)

         image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
-        text_embeds = text_cond['text_embeds']
+        text_embeds = text_cond['text_embed']

         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
         image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)

-        text_image_sims = einsum('b r d, b r d -> b r')
+        text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
         top_sim_indices = text_image_sims.topk(k = 1).indices

-        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b d', d = image_embed_dim)
+        top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)

         top_image_embeds = image_embeds.gather(1, top_sim_indices)
-        return top_image_embeds
+        return rearrange(top_image_embeds, 'b 1 d -> b d')

     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
```
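
`sample` draws `num_samples_per_batch` candidate image embeddings per caption and keeps the candidate whose L2-normalized embedding is most cosine-similar to the text embedding, echoing the paper's trick of sampling two embeddings from the prior and keeping the better one; the corrected `repeat` pattern keeps the singleton axis so the index tensor matches `gather`'s expected rank. A toy re-ranking sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat

b, r, d = 3, 2, 8                     # 3 captions, 2 candidates each, dim 8
text_embeds  = torch.randn(b, r, d)   # text embedding repeated per candidate
image_embeds = torch.randn(b, r, d)   # sampled candidates

l2norm = lambda t: F.normalize(t, dim = -1)
sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))

top = repeat(sims.topk(k = 1).indices, 'b 1 -> b 1 d', d = d)  # rank-3, like image_embeds
best = rearrange(image_embeds.gather(1, top), 'b 1 d -> b d')
print(best.shape)  # torch.Size([3, 8])
```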
```diff
@@ -442,14 +471,14 @@ class DiffusionPrior(nn.Module):
         )

     def p_losses(self, image_embed, t, text_cond, noise = None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
+        noise = default(noise, lambda: torch.randn_like(image_embed))

         image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)

         x_recon = self.net(
             image_embed_noisy,
             t,
-            cond_prob_drop = self.cond_prob_drop,
+            cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
@@ -472,7 +501,7 @@ class DiffusionPrior(nn.Module):
         image_embed = self.get_image_embed(image)
         text_cond = self.get_text_cond(text)

-        loss = self.p_losses(x, times, image_embed = image_embed, text_cond = text_cond, *args, **kwargs)
+        loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
         return loss

 # decoder
@@ -519,7 +548,7 @@ class ConvNextBlock(nn.Module):

         inner_dim = int(dim_out * mult)
         self.net = nn.Sequential(
-            RMSNorm(dim) if norm else nn.Identity(),
+            ChanRMSNorm(dim) if norm else nn.Identity(),
             nn.Conv2d(dim, inner_dim, 3, padding = 1),
             nn.GELU(),
             nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -538,21 +567,6 @@ class ConvNextBlock(nn.Module):
         h = self.net(h)
         return h + self.res_conv(x)

-class EinopsToAndFrom(nn.Module):
-    def __init__(self, from_einops, to_einops, fn):
-        super().__init__()
-        self.from_einops = from_einops
-        self.to_einops = to_einops
-        self.fn = fn
-
-    def forward(self, x, **kwargs):
-        shape = x.shape
-        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
-        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
-        x = self.fn(x, **kwargs)
-        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
-        return x
-
 class Unet(nn.Module):
     def __init__(
         self,
```
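
The local `EinopsToAndFrom` class is deleted in favor of the identical helper now imported from `einops_exts.torch` at the top of the file. It rearranges the input from one einops pattern to another around a wrapped module and restores the original shape afterwards, which is how the next hunk applies sequence attention to a `b c h w` feature map. A small sketch, with `nn.Linear` standing in for the wrapped attention:

```python
import torch
from torch import nn
from einops_exts.torch import EinopsToAndFrom

# 'b c h w' -> 'b (h w) c' -> fn -> back to 'b c h w'
attend = EinopsToAndFrom('b c h w', 'b (h w) c', nn.Linear(32, 32))
feats = torch.randn(2, 32, 8, 8)
print(attend(feats).shape)  # torch.Size([2, 32, 8, 8])
```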
```diff
@@ -597,6 +611,7 @@ class Unet(nn.Module):
         ]))

         mid_dim = dims[-1]

         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
@@ -627,7 +642,7 @@ class Unet(nn.Module):
             return self.forward(x, **kwargs)

         logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
+        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
@@ -637,11 +652,12 @@ class Unet(nn.Module):
         *,
         image_embed,
         text_encodings = None,
-        cond_prob_drop = 0.
+        cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
         t = self.time_mlp(time)

-        cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
+        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)

         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -652,7 +668,7 @@ class Unet(nn.Module):
             rearrange(self.null_image_embed, 'd -> 1 d')
         )

-        cond = torch.cat((t, image_embed), dim = -1)
+        t = torch.cat((t, image_embed), dim = -1)

         hiddens = []
@@ -663,7 +679,7 @@ class Unet(nn.Module):
             x = downsample(x)

         x = self.mid_block1(x, t)
-        x = self.attn(x)
+        x = self.mid_attn(x)
         x = self.mid_block2(x, t)

         for convnext, convnext2, upsample in self.ups:
@@ -681,17 +697,18 @@ class Decoder(nn.Module):
         *,
         clip,
         timesteps = 1000,
-        cond_prob_drop = 0.2,
+        cond_drop_prob = 0.2,
         loss_type = 'l1'
     ):
         super().__init__()
+        assert isinstance(clip, CLIP)
         freeze_model_and_make_eval_(clip)
         self.clip = clip

         self.net = net
         self.channels = clip.image_channels
         self.image_size = clip.image_size
-        self.cond_prob_drop = cond_prob_drop
+        self.cond_drop_prob = cond_drop_prob

         betas = cosine_beta_schedule(timesteps)
@@ -768,7 +785,7 @@ class Decoder(nn.Module):
     @torch.no_grad()
     def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -800,7 +817,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, image_embed, t, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -809,7 +826,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
-            cond_prob_drop = self.cond_prob_drop
+            cond_drop_prob = self.cond_drop_prob
         )

         if self.loss_type == 'l1':
@@ -821,14 +838,14 @@ class Decoder(nn.Module):
         return loss

-    def forward(self, image, *args, **kwargs):
+    def forward(self, image):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)

         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
         image_embed = self.get_image_embed(image)

-        loss = self.p_losses(x, times, image_embed = image_embed, *args, **kwargs)
+        loss = self.p_losses(image, times, image_embed = image_embed)
         return loss

 # main class
@@ -839,23 +856,27 @@ class DALLE2(nn.Module):
         *,
         prior,
         decoder,
-        tokenizer = None
+        prior_num_samples = 2
     ):
         super().__init__()
-        assert isinstance(prior), DiffusionPrior
-        assert isinstance(decoder), Decoder
-        self.tokenizer = tokenizer
+        assert isinstance(prior, DiffusionPrior)
+        assert isinstance(decoder, Decoder)
         self.prior = prior.eval()
         self.decoder = decoder.eval()
+        self.prior_num_samples = prior_num_samples

     @torch.no_grad()
     def forward(
         self,
         *,
         text
     ):
-        if isinstance(text, str):
-            assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
-            text = self.tokenizer.encode(text)
+        device = next(self.parameters()).device

-        image_embed = prior.sample(text, num_samples_per_batch = 2)
-        images = decoder.sample(image_embed)
+        if isinstance(text, str) or is_list_str(text):
+            text = [text] if not isinstance(text, (list, tuple)) else text
+            text = tokenizer.tokenize(text).to(device)
+
+        print(text.shape, type(text))
+        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
+        images = self.decoder.sample(image_embed)
         return images
```
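
With the renames, the module-level `tokenizer`, and the re-exported classes in place, the pipeline can be wired end to end, as the commit message promises. A hedged sketch in the style of the project README; every hyperparameter below (dims, depths, timesteps) is an illustrative stand-in rather than a value taken from this commit, and training of the prior and decoder on (image, text) pairs is elided:

```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from x_clip import CLIP

# frozen CLIP providing the shared text/image latent space
clip = CLIP(
    dim_text = 512, dim_image = 512, dim_latent = 512,
    num_text_tokens = 49408, text_enc_depth = 6, text_seq_len = 256, text_heads = 8,
    visual_enc_depth = 6, visual_image_size = 256, visual_patch_size = 32, visual_heads = 8
)

# prior: text embedding -> image embedding
prior_network = DiffusionPriorNetwork(dim = 512, num_timesteps = 100, depth = 6, dim_head = 64, heads = 8)
diffusion_prior = DiffusionPrior(net = prior_network, clip = clip, timesteps = 100, cond_drop_prob = 0.2)

# decoder: image embedding -> image
unet = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
decoder = Decoder(net = unet, clip = clip, timesteps = 100, cond_drop_prob = 0.2)

# ... train diffusion_prior and decoder here ...

dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)
images = dalle2(text = ['a butterfly trying to escape a tornado'])  # strings are tokenized internally
```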