have researcher explicitly state upfront whether to condition with text encodings in cascading ddpm decoder, have DALLE-2 class take care of passing in text if feature turned on

This commit is contained in:
Phil Wang
2022-04-26 09:47:09 -07:00
parent 7ba6357c05
commit 9878be760b
3 changed files with 18 additions and 5 deletions

View File

@@ -894,6 +894,7 @@ class Unet(nn.Module):
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
):
super().__init__()
@@ -944,7 +945,7 @@ class Unet(nn.Module):
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
# attention related params
@@ -1072,7 +1073,7 @@ class Unet(nn.Module):
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
self.null_text_embed[:, :text_tokens.shape[1]]
)
# main conditioning tokens (c)
@@ -1170,6 +1171,7 @@ class Decoder(nn.Module):
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -1178,6 +1180,8 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.condition_on_text_encodings = condition_on_text_encodings
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1421,6 +1425,8 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
@@ -1481,6 +1487,8 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
@@ -1508,7 +1516,9 @@ class DALLE2(nn.Module):
assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
@torch.no_grad()
@eval_decorator
@@ -1525,7 +1535,9 @@ class DALLE2(nn.Module):
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
if one_text:
return images[0]