diff --git a/README.md b/README.md
index adae23f..9022190 100644
--- a/README.md
+++ b/README.md
@@ -276,7 +276,7 @@ decoder = Decoder(
     cond_drop_prob = 0.2
 ).cuda()
 
-loss = decoder(images)
+loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
 loss.backward()
 
 # do above for many steps
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index eb03cd7..f717d1e 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -722,7 +722,7 @@ class Unet(nn.Module):
         dims = [channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
-        # time and image embeddings
+        # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
 
@@ -739,9 +739,12 @@
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if image_embed_dim != cond_dim else nn.Identity()
 
+        self.text_to_cond = nn.LazyLinear(cond_dim)
+
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
 
         # layers
 
@@ -806,6 +809,7 @@
         time_tokens = self.time_mlp(time)
 
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
@@ -813,12 +817,31 @@
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            rearrange(cond_prob_mask, 'b -> b 1 1'),
+            cond_prob_mask,
             image_tokens,
             self.null_image_embed
         )
 
-        c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
+        # take care of text encodings (optional)
+
+        if exists(text_encodings):
+            text_tokens = self.text_to_cond(text_encodings)
+            text_tokens = torch.where(
+                cond_prob_mask,
+                text_tokens,
+                self.null_text_embed
+            )
+
+        # main conditioning tokens (c)
+
+        c = torch.cat((time_tokens, image_tokens), dim = -2)
+
+        # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
+
+        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+
+        # go through the layers of the unet, down and up
 
         hiddens = []
 
@@ -828,9 +851,9 @@
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, c)
+        x = self.mid_block1(x, mid_c)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, c)
+        x = self.mid_block2(x, mid_c)
 
         for convnext, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
@@ -896,6 +919,10 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    def get_text_encodings(self, text):
+        text_encodings = self.clip.text_transformer(text)
+        return text_encodings[:, 1:]
+
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
@@ -923,8 +950,8 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
+    def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
@@ -933,31 +960,32 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
+    def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed, cond_scale = 1.):
+    def sample(self, image_embed, text = None, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -967,7 +995,7 @@
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def p_losses(self, x_start, t, *, image_embed, noise = None):
+    def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -976,6 +1004,7 @@ class Decoder(nn.Module):
             x_noisy,
             t,
             image_embed = image_embed,
+            text_encodings = text_encodings,
             cond_drop_prob = self.cond_drop_prob
         )
 
@@ -988,14 +1017,16 @@ class Decoder(nn.Module):
         return loss
 
-    def forward(self, image):
+    def forward(self, image, text = None):
         b, device, img_size, = image.shape[0], image.device, self.image_size
         check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
 
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        image_embed = self.get_image_embed(image)
-        loss = self.p_losses(image, times, image_embed = image_embed)
+        image_embed = self.get_image_embed(image)
+        text_encodings = self.get_text_encodings(text) if exists(text) else None
+
+        loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
         return loss
 
 # main class
 
diff --git a/setup.py b/setup.py
index 0f15d48..9abd021 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.12',
+  version = '0.0.14',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
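For context, a minimal usage sketch of the optional text conditioning this patch adds to `Decoder`. It assumes `decoder`, `images`, and `text` are already set up exactly as in the existing README example; the `cond_scale = 2.` value is only illustrative, and `get_image_embed` / `get_text_encodings` / `sample` are the `Decoder` methods shown in the diff above.

```python
# training: text tokens are optional; when passed, the Unet also cross attends to the CLIP
# text encodings, with classifier free guidance handled via the new null text embedding
loss = decoder(images, text)
loss.backward()

# sampling: the same optional `text` argument is threaded through
# sample -> p_sample_loop -> p_sample -> p_mean_variance
image_embed = decoder.get_image_embed(images)  # CLIP image embedding via the Decoder's own helper
sampled_images = decoder.sample(image_embed, text = text, cond_scale = 2.)
```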