remove text masking altogether in favor of deriving it from the text encodings (padded positions in the text encodings must be zero-valued)
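The CLIP adapters already zero out the encodings at padded positions via masked_fill, so any consumer can recover the boolean mask with torch.any(text_encodings != 0., dim = -1) instead of threading a separate text_mask argument through every signature. A minimal sketch of the derivation (toy tensors, not repo code); note the stated requirement: a genuinely all-zero token encoding would be misread as padding.

import torch

# toy batch: 2 sequences of 4 tokens with 3 features; second sequence padded after 2 tokens
text_encodings = torch.randn(2, 4, 3)
text_encodings[1, 2:] = 0.  # adapters zero-fill padded positions

# a position counts as a real token if any feature is nonzero
mask = torch.any(text_encodings != 0., dim = -1)
print(mask)
# tensor([[ True,  True,  True,  True],
#         [ True,  True, False, False]])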

Phil Wang
2022-07-12 15:40:31 -07:00
parent bb3ff0ac67
commit e76e89f9eb
4 changed files with 28 additions and 41 deletions

View File

@@ -421,7 +421,7 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
 ## Training on Preprocessed CLIP Embeddings
-It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
+It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
 Working example below
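The working example the README refers to sits below this hunk and is not part of the diff; a minimal sketch of the call shape it describes, assuming `diffusion_prior` is an already-constructed DiffusionPrior and 512-dimensional CLIP embeddings:

import torch

# hypothetical preprocessed CLIP outputs
text_embed  = torch.randn(4, 512)
image_embed = torch.randn(4, 512)
text_encodings = torch.randn(4, 256, 512)  # per-token encodings, zero-filled at padding

loss = diffusion_prior(
    text_embed = text_embed,
    image_embed = image_embed,
    text_encodings = text_encodings  # optional
)
loss.backward()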

View File

@@ -165,7 +165,7 @@ def unnormalize_zero_to_one(normed_img):
 # clip related adapters
-EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 'text_mask'])
+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
 EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
 class BaseClipAdapter(nn.Module):
@@ -226,7 +226,7 @@ class XClipAdapter(BaseClipAdapter):
 text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
 text_embed = self.clip.to_text_latent(text_cls)
 text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
-return EmbeddedText(l2norm(text_embed), text_encodings, text_mask)
+return EmbeddedText(l2norm(text_embed), text_encodings)
 @torch.no_grad()
 def embed_image(self, image):
@@ -262,7 +262,7 @@ class CoCaAdapter(BaseClipAdapter):
 text_mask = text != 0
 text_embed, text_encodings = self.clip.embed_text(text)
 text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
-return EmbeddedText(text_embed, text_encodings, text_mask)
+return EmbeddedText(text_embed, text_encodings)
 @torch.no_grad()
 def embed_image(self, image):
@@ -323,7 +323,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
 text_encodings = self.text_encodings
 text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
 del self.text_encodings
-return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
+return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
 @torch.no_grad()
 def embed_image(self, image):
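Each adapter zeroes the padded positions before returning, which is the invariant that makes the downstream torch.any derivation valid. The zeroing step in isolation (toy tensors):

import torch

text_encodings = torch.randn(1, 4, 3)
text_mask = torch.tensor([[True, True, False, False]])  # e.g. from text != 0

# broadcast the mask over the feature dimension and zero the padded rows
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
# rows 2 and 3 are now all zeros, so the mask is recoverable downstream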
@@ -871,7 +871,6 @@ class DiffusionPriorNetwork(nn.Module):
 if not exists(text_encodings):
     text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
-if not exists(mask) or mask.numel() == 0:
-    mask = torch.any(text_encodings != 0., dim = -1)
+mask = torch.any(text_encodings != 0., dim = -1)
 # classifier free guidance
@@ -889,7 +888,6 @@ class DiffusionPriorNetwork(nn.Module):
 # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
 # but let's just do it right
-if exists(mask):
-    attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
-    mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
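For context, the F.pad above appends True entries so that the learned queries, noised image embedding, and time embeddings are always attendable; only the padded text positions stay masked. A toy illustration, assuming one of each non-text token:

import torch
import torch.nn.functional as F

mask = torch.tensor([[True, True, False, False]])  # derived from the text encodings
attend_padding = 1 + 1 + 1  # learned query + time embed + image embed
mask = F.pad(mask, (0, attend_padding), value = True)
print(mask)
# tensor([[ True,  True, False, False,  True,  True,  True]])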
@@ -1152,12 +1150,12 @@ class DiffusionPrior(nn.Module):
 batch_size = text.shape[0]
 image_embed_dim = self.image_embed_dim
-text_embed, text_encodings, text_mask = self.clip.embed_text(text)
+text_embed, text_encodings = self.clip.embed_text(text)
 text_cond = dict(text_embed = text_embed)
 if self.condition_on_text_encodings:
-    text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
+    text_cond = {**text_cond, 'text_encodings': text_encodings}
 image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
@@ -1185,7 +1183,6 @@ class DiffusionPrior(nn.Module):
 text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
 image_embed = None,
 text_encodings = None, # as well as CLIP text encodings
-text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
 *args,
 **kwargs
 ):
@@ -1199,13 +1196,13 @@ class DiffusionPrior(nn.Module):
 # calculate text conditionings, based on what is passed in
 if exists(text):
-    text_embed, text_encodings, text_mask = self.clip.embed_text(text)
+    text_embed, text_encodings = self.clip.embed_text(text)
 text_cond = dict(text_embed = text_embed)
 if self.condition_on_text_encodings:
     assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
-    text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
+    text_cond = {**text_cond, 'text_encodings': text_encodings}
 # timestep conditioning from ddpm
@@ -1744,7 +1741,6 @@ class Unet(nn.Module):
 image_embed,
 lowres_cond_img = None,
 text_encodings = None,
-text_mask = None,
 image_cond_drop_prob = 0.,
 text_cond_drop_prob = 0.,
 blur_sigma = None,
@@ -1818,7 +1814,6 @@ class Unet(nn.Module):
 if exists(text_encodings) and self.cond_on_text_encodings:
     assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
-    if not exists(text_mask) or text_mask.numel() == 0:
-        text_mask = torch.any(text_encodings != 0., dim = -1)
+    text_mask = torch.any(text_encodings != 0., dim = -1)
     text_tokens = self.text_to_cond(text_encodings)
@@ -2218,10 +2213,10 @@ class Decoder(nn.Module):
 x = x.clamp(-s, s) / s
 return x
-def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
 assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
-pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
+pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
 if learned_variance:
     pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
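forward_with_cond_scale is this codebase's usual classifier-free guidance wrapper; a sketch of the pattern, under the assumption that it combines a conditioned and a fully-dropped forward pass (the drop-prob argument names are taken from the Unet forward signature shown above):

def forward_with_cond_scale(model, *args, cond_scale = 1., **kwargs):
    # fully conditioned prediction
    logits = model(*args, **kwargs)
    if cond_scale == 1.:
        return logits
    # unconditional prediction: drop both text and image conditioning
    null_logits = model(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
    # move away from the unconditional prediction by the guidance scale
    return null_logits + (logits - null_logits) * cond_scale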
@@ -2253,16 +2248,16 @@ class Decoder(nn.Module):
 return model_mean, posterior_variance, posterior_log_variance
 @torch.no_grad()
-def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
+def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
 b, *_, device = *x.shape, x.device
-model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
+model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
 noise = torch.randn_like(x)
 # no noise when t == 0
 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
 return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 @torch.no_grad()
-def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
 device = self.device
 b = shape[0]
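In p_sample above, nonzero_mask zeroes the noise term exactly where t == 0, so the final denoising step returns the model mean unperturbed. A quick check of the reshape trick (illustrative values):

import torch

t = torch.tensor([0, 5, 9])   # per-sample timesteps
x_shape = (3, 4, 64, 64)      # hypothetical (b, c, h, w)
nonzero_mask = (1 - (t == 0).float()).reshape(3, *((1,) * (len(x_shape) - 1)))
print(nonzero_mask.flatten())  # tensor([0., 1., 1.]) -> noise suppressed for the t == 0 sample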
@@ -2278,7 +2273,6 @@ class Decoder(nn.Module):
 torch.full((b,), i, device = device, dtype = torch.long),
 image_embed = image_embed,
 text_encodings = text_encodings,
-text_mask = text_mask,
 cond_scale = cond_scale,
 lowres_cond_img = lowres_cond_img,
 predict_x_start = predict_x_start,
@@ -2291,7 +2285,7 @@ class Decoder(nn.Module):
 return unnormalize_img
 @torch.no_grad()
-def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False):
 batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
 times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
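The linspace above picks the strided DDIM timesteps; the trailing [:-1] drops total_timesteps itself, which would be out of range as an index. What it produces for, say, 1000 training timesteps sampled in 4 DDIM steps:

import torch

total_timesteps, timesteps = 1000, 4
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
print(times)  # tensor([  0., 200., 400., 600., 800.])
# reversed and taken pairwise, these become the (time, time_next) jumps of the sampler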
@@ -2307,7 +2301,7 @@ class Decoder(nn.Module):
 time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 if learned_variance:
     pred, _ = pred.chunk(2, dim = 1)
@@ -2346,7 +2340,7 @@ class Decoder(nn.Module):
 return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
-def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
 noise = default(noise, lambda: torch.randn_like(x_start))
 # normalize to [-1, 1]
@@ -2364,7 +2358,6 @@ class Decoder(nn.Module):
 times,
 image_embed = image_embed,
 text_encodings = text_encodings,
-text_mask = text_mask,
 lowres_cond_img = lowres_cond_img,
 image_cond_drop_prob = self.image_cond_drop_prob,
 text_cond_drop_prob = self.text_cond_drop_prob,
@@ -2424,7 +2417,6 @@ class Decoder(nn.Module):
 self,
 image_embed = None,
 text = None,
-text_mask = None,
 text_encodings = None,
 batch_size = 1,
 cond_scale = 1.,
@@ -2438,7 +2430,7 @@ class Decoder(nn.Module):
 if exists(text) and not exists(text_encodings) and not self.unconditional:
     assert exists(self.clip)
-    _, text_encodings, text_mask = self.clip.embed_text(text)
+    _, text_encodings = self.clip.embed_text(text)
 assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
 assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -2468,7 +2460,6 @@ class Decoder(nn.Module):
 shape,
 image_embed = image_embed,
 text_encodings = text_encodings,
-text_mask = text_mask,
 cond_scale = cond_scale,
 predict_x_start = predict_x_start,
 learned_variance = learned_variance,
@@ -2492,7 +2483,6 @@ class Decoder(nn.Module):
 text = None,
 image_embed = None,
 text_encodings = None,
-text_mask = None,
 unet_number = None,
 return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
 ):
@@ -2521,7 +2511,7 @@ class Decoder(nn.Module):
 if exists(text) and not exists(text_encodings) and not self.unconditional:
     assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
-    _, text_encodings, text_mask = self.clip.embed_text(text)
+    _, text_encodings = self.clip.embed_text(text)
 assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
 assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
@@ -2544,7 +2534,7 @@ class Decoder(nn.Module):
 image = vae.encode(image)
 lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
-losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
+losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
 if not return_lowres_cond_image:
     return losses

View File

@@ -1 +1 @@
-__version__ = '0.21.3'
+__version__ = '0.22.0'

View File

@@ -126,9 +126,9 @@ def report_cosine_sims(
 # we are text conditioned, we produce an embedding from the tokenized text
 if text_conditioned:
-    text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
+    text_embedding, text_encodings = trainer.embed_text(text_data)
     text_cond = dict(
-        text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
+        text_embed=text_embedding, text_encodings=text_encodings
     )
 else:
     text_embedding = text_data
@@ -146,15 +146,12 @@ def report_cosine_sims(
 if text_conditioned:
     text_encodings_shuffled = text_encodings[rolled_idx]
-    text_mask_shuffled = text_mask[rolled_idx]
 else:
     text_encodings_shuffled = None
-    text_mask_shuffled = None
 text_cond_shuffled = dict(
     text_embed=text_embed_shuffled,
-    text_encodings=text_encodings_shuffled,
-    mask=text_mask_shuffled,
+    text_encodings=text_encodings_shuffled
 )
 # prepare the text embedding
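rolled_idx pairs each image embedding with a different sample's text to produce a baseline of unrelated cosine similarities; its construction is outside this hunk, but the usual pattern is a one-position roll (assumed here, with the tensor names from above):

import torch

batch = text_encodings.shape[0]
rolled_idx = torch.roll(torch.arange(batch), 1)  # shift every index by one position
text_embed_shuffled = text_embedding[rolled_idx]
text_encodings_shuffled = text_encodings[rolled_idx]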