Compare commits


6 Commits

3 changed files with 186 additions and 65 deletions

README.md

@@ -101,7 +101,7 @@ clip = CLIP(
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -264,7 +264,7 @@ loss.backward()
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -276,7 +276,7 @@ decoder = Decoder(
cond_drop_prob = 0.2
).cuda()
loss = decoder(images)
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though the paper hints it did not help much
loss.backward()
# do above for many steps
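# hypothetical follow-up sketch (not part of this diff): once trained, sampling would
# presumably go through the updated Decoder.sample signature shown further below, e.g.
# images = decoder.sample(image_embed, text = text)  # text is optional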
@@ -319,7 +319,7 @@ Offer training wrappers
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much of a difference)
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much of a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab

dalle2_pytorch/dalle2_pytorch.py

@@ -118,14 +118,13 @@ class ChanRMSNorm(RMSNorm):
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
return self.fn(x, **kwargs) + x
# mlp
@@ -164,12 +163,21 @@ class MLP(nn.Module):
# feedforward
def FeedForward(dim, mult = 4, dropout = 0.):
class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
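# note on the block above, with illustrative values dim = 512 and mult = 4: the input is
# projected to 2 * 2048 features, SwiGLU gates one half with F.silu of the other (back to 2048),
# an optional post-activation RMSNorm is applied, and a final linear maps back down to 512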
@@ -231,6 +239,7 @@ class Attention(nn.Module):
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -320,8 +329,8 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right
if exists(mask):
all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -562,6 +571,17 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
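# note: the outer product of the [1, 2, 1] filter with itself gives the 3x3 binomial blur kernel
# [[1, 2, 1], [2, 4, 2], [1, 2, 1]], which filter2d normalizes so the blur preserves brightness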
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -589,10 +609,17 @@ class ConvNextBlock(nn.Module):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -609,21 +636,82 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
if exists(self.cross_attn):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
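# usage sketch with illustrative shapes: CrossAttention(dim = 128, context_dim = 512) takes
# queries of shape (b, n, 128) and a context of shape (b, m, 512) and returns (b, n, 128);
# the learned null key / value pair is prepended so every query has something to attend to
# when the context is dropped or fully masked out under classifier free guidance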
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
time_dim = None,
cond_dim = None,
num_image_tokens = 4,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -634,18 +722,31 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
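# projects the single image embedding into num_image_tokens conditioning tokens of width
# cond_dim; left as identity when image_embed_dim already matches cond_dim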
cond_dim = time_dim + image_embed_dim
self.text_to_cond = nn.LazyLinear(cond_dim)
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
@@ -655,7 +756,7 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -663,7 +764,7 @@ class Unet(nn.Module):
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -705,52 +806,63 @@ class Unet(nn.Module):
cond_drop_prob = 0.
):
batch_size, device = x.shape[0], x.device
t = self.time_mlp(time)
time_tokens = self.time_mlp(time)
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
t = torch.cat((t, image_embed), dim = -1)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
)
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
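# shape sketch (assuming image_embed_dim != cond_dim so image_to_cond is the linear projection):
# time_tokens is (b, 1, cond_dim), image_tokens is (b, num_image_tokens, cond_dim), so c is
# (b, 1 + num_image_tokens, cond_dim); mid_c appends the text tokens along the same token axis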
# go through the layers of the unet, down and up
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block1(x, mid_c)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class Decoder(nn.Module):
def __init__(
self,
@@ -807,6 +919,10 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
@@ -834,8 +950,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -844,31 +960,32 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
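# nonzero_mask is 1 for every sample except where t == 0, broadcast across the image dimensions,
# so noise is only added for steps before the final denoising step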
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed, cond_scale = 1.):
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
text_encodings = self.get_text_encodings(text) if exists(text) else None
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -878,7 +995,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, *, image_embed, noise = None):
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -887,6 +1004,7 @@ class Decoder(nn.Module):
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
cond_drop_prob = self.cond_drop_prob
)
@@ -899,14 +1017,16 @@ class Decoder(nn.Module):
return loss
def forward(self, image):
def forward(self, image, text = None):
b, device, img_size = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(image, times, image_embed = image_embed)
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
return loss
# main class
@@ -922,11 +1042,12 @@ class DALLE2(nn.Module):
super().__init__()
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior.eval()
self.decoder = decoder.eval()
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
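# the decorators below presumably replace the permanent .eval() calls removed above:
# torch.no_grad disables gradient tracking, and eval_decorator is assumed to switch the
# module to eval mode for the forward call and restore the previous training state afterwards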
@torch.no_grad()
@eval_decorator
def forward(
self,
text,

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.9',
version = '0.0.14',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',