allow for decoder conditioning on the text encodings from CLIP, if they are passed in. use LazyLinear so researchers do not need to worry about the text encoding dimension; remove it later if it does not work well

Phil Wang
2022-04-14 11:46:45 -07:00
parent 69e822b7f8
commit 9f55c24db6
3 changed files with 50 additions and 19 deletions
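As background for the LazyLinear choice described in the commit message, the sketch below (standalone and illustrative, with made-up shapes; not code from this repository) shows how torch.nn.LazyLinear defers its input dimension until the first forward pass, which is why the decoder's Unet does not need to know the CLIP text encoding width ahead of time.

import torch
from torch import nn

# hypothetical CLIP text encoding width; in practice it depends on which CLIP is plugged in
text_encodings = torch.randn(4, 255, 512)     # (batch, text tokens, encoding dim)

text_to_cond = nn.LazyLinear(128)             # only the output size (cond_dim) is given
cond_tokens = text_to_cond(text_encodings)    # first call infers in_features = 512
print(cond_tokens.shape)                      # torch.Size([4, 255, 128])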


@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()
 
-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()
 
 # do above for many steps
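A hedged usage sketch of the new optional argument, continuing the usage snippet in the hunk above: it assumes the decoder and images already defined there, and the token id range, sequence length, and batch size below are illustrative assumptions that must match the wrapped CLIP's tokenizer and the image batch.

# `decoder` and `images` come from the snippet above; vocab size and sequence length
# here are illustrative, and the batch size must match that of `images`
text = torch.randint(0, 49408, (4, 256)).cuda()   # tokenized captions, (batch, seq len)

loss = decoder(images, text)   # same as decoder(images), plus optional text conditioning
loss.backward()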


@@ -722,7 +722,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        # time and image embeddings
+        # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
@@ -739,9 +739,12 @@ class Unet(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()
 
+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
 
         # layers
@@ -806,6 +809,7 @@ class Unet(nn.Module):
         time_tokens = self.time_mlp(time)
 
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -813,12 +817,31 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )
 
-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+
+        # go through the layers of the unet, down and up
+
         hiddens = []
@@ -828,9 +851,9 @@ class Unet(nn.Module):
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
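To clarify the conditioning-token bookkeeping in the hunks above, here is a standalone sketch (dims and token counts are illustrative, and the keep-mask is built directly here rather than with the repo's prob_mask_like helper): real tokens are swapped for learned null embeddings under a broadcastable mask for classifier free guidance, and the text tokens are only appended to the conditioning (mid_c) consumed by the middle cross attention blocks.

import torch

batch, cond_dim = 4, 128
num_image_tokens, num_text_tokens = 4, 77           # illustrative token counts

time_tokens  = torch.randn(batch, 1, cond_dim)
image_tokens = torch.randn(batch, num_image_tokens, cond_dim)
text_tokens  = torch.randn(batch, num_text_tokens, cond_dim)

null_image_embed = torch.randn(1, num_image_tokens, cond_dim)   # nn.Parameter in the real model
null_text_embed  = torch.randn(1, 1, cond_dim)

# boolean keep-mask of shape (b, 1, 1): True keeps the conditioning, False drops it
cond_prob_mask = (torch.rand(batch) > 0.2).view(batch, 1, 1)

image_tokens = torch.where(cond_prob_mask, image_tokens, null_image_embed)
text_tokens  = torch.where(cond_prob_mask, text_tokens, null_text_embed)

c     = torch.cat((time_tokens, image_tokens), dim = -2)   # fed to every conditioned block
mid_c = torch.cat((c, text_tokens), dim = -2)              # fed only to the bottleneck blocks
print(c.shape, mid_c.shape)                                # (4, 5, 128) (4, 82, 128)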
@@ -896,6 +919,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
@@ -923,8 +950,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -933,31 +960,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -967,7 +995,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -976,6 +1004,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )
@@ -988,14 +1017,16 @@ class Decoder(nn.Module):
         return loss
 
-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
 
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        image_embed = self.get_image_embed(image)
-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss
 
 # main class
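The sampling methods above also thread cond_scale down to net.forward_with_cond_scale. As a reminder of what such a scale usually does in classifier free guidance (a generic sketch under that assumption, not the repository's exact implementation): the prediction made with conditioning and the prediction made with the null embeddings are blended, with larger scales extrapolating further away from the unconditional prediction.

import torch

def scale_guidance(pred_cond, pred_null, cond_scale = 1.):
    # cond_scale = 1. returns the conditional prediction unchanged;
    # cond_scale > 1. pushes the output away from the unconditional one
    return pred_null + (pred_cond - pred_null) * cond_scale

pred_cond = torch.randn(2, 3, 64, 64)   # e.g. predicted noise with image/text conditioning
pred_null = torch.randn(2, 3, 64, 64)   # predicted noise with conditioning dropped
guided = scale_guidance(pred_cond, pred_null, cond_scale = 3.)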


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.12',
+  version = '0.0.14',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',