Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
have researcher explicitly state upfront whether to condition with text encodings in cascading ddpm decoder, have DALLE-2 class take care of passing in text if feature turned on
@@ -348,7 +348,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2
+    cond_drop_prob = 0.2,
+    condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
 for unet_number in (1, 2):
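With the flag exposed in the README example, training the decoder on text means passing tokenized text alongside the images. A minimal sketch of what that looks like, assuming `decoder` was constructed with condition_on_text_encodings = True (the token shapes follow the repository's other README examples; treat the vocabulary size and sequence length as placeholders):

import torch

# assumes `decoder` from the example above, built with condition_on_text_encodings = True
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (4, 256)).cuda() # tokenized text; vocab size and length are assumptions

for unet_number in (1, 2):
    loss = decoder(images, text = text, unet_number = unet_number)
    loss.backward()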
@@ -894,6 +894,7 @@ class Unet(nn.Module):
         sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
+        max_text_len = 256,
         cond_on_image_embeds = False,
     ):
         super().__init__()
@@ -944,7 +945,7 @@ class Unet(nn.Module):
         # for classifier free guidance
 
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
-        self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
 
         # attention related params
 
@@ -1072,7 +1073,7 @@ class Unet(nn.Module):
             text_tokens = torch.where(
                 cond_prob_mask,
                 text_tokens,
-                self.null_text_embed
+                self.null_text_embed[:, :text_tokens.shape[1]]
             )
 
         # main conditioning tokens (c)
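The two Unet changes above work together: the null text embedding is now a full max_text_len sequence, and the forward pass slices it down to however many text tokens were actually passed before the classifier-free-guidance swap. A minimal sketch of that masking step, with the shapes and the cond_prob_mask construction assumed for illustration:

import torch

batch, seq_len, cond_dim, max_text_len = 4, 77, 512, 256

null_text_embed = torch.randn(1, max_text_len, cond_dim) # an nn.Parameter in the real Unet
text_tokens = torch.randn(batch, seq_len, cond_dim)

# per-sample keep mask: drop conditioning with probability cond_drop_prob
cond_drop_prob = 0.2
cond_prob_mask = torch.rand(batch, 1, 1) >= cond_drop_prob

# where the mask is False, substitute the null embedding, sliced to the live sequence length
text_tokens = torch.where(
    cond_prob_mask,
    text_tokens,
    null_text_embed[:, :text_tokens.shape[1]]
)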
@@ -1170,6 +1171,7 @@ class Decoder(nn.Module):
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_sigma = 0.1, # cascading ddpm - blur sigma
         blur_kernel_size = 3, # cascading ddpm - blur kernel size
+        condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -1178,6 +1180,8 @@ class Decoder(nn.Module):
         self.clip_image_size = clip.image_size
         self.channels = clip.image_channels
 
+        self.condition_on_text_encodings = condition_on_text_encodings
+
         # automatically take care of ensuring that first unet is unconditional
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
@@ -1421,6 +1425,8 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) else None
 
+        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+
         img = None
 
         for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
@@ -1481,6 +1487,8 @@ class Decoder(nn.Module):
 
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
+        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)
 
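Both the sampling path and the training path now guard the same invariant: if the decoder was told to condition on text encodings, the caller must supply either raw text (from which sample computes encodings) or, in the training forward, precomputed text_encodings. A small sketch of the guard's behavior, using the repo's exists helper (assumed here to be the usual None check):

def exists(val):
    return val is not None

def check(condition_on_text_encodings, text_encodings):
    # same predicate as the asserts added above
    assert not (condition_on_text_encodings and not exists(text_encodings)), \
        'text or text encodings must be passed into decoder if specified'

check(False, None)    # fine: decoder is unconditional on text
check(True, object()) # fine: encodings were supplied (or computed from text)
try:
    check(True, None) # conditioning requested but nothing passed
except AssertionError as e:
    print(e)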
@@ -1508,7 +1516,9 @@ class DALLE2(nn.Module):
         assert isinstance(decoder, Decoder)
         self.prior = prior
         self.decoder = decoder
 
         self.prior_num_samples = prior_num_samples
+        self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
 
     @torch.no_grad()
     @eval_decorator
@@ -1525,7 +1535,9 @@ class DALLE2(nn.Module):
         text = tokenizer.tokenize(text).to(device)
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
+
+        text_cond = text if self.decoder_need_text_cond else None
+        images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
 
         if one_text:
             return images[0]
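At the top level, nothing changes for the user beyond building the decoder with the flag: DALLE2 records decoder_need_text_cond at construction and forwards the tokenized text into decoder.sample only when it is set. A sketch of end-to-end use in the style of the repository's README (prior and decoder construction elided):

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# if decoder was built with condition_on_text_encodings = True, the tokenized
# text is passed through to decoder.sample automatically; otherwise text_cond is None
images = dalle2(['cute puppy chasing after a squirrel'])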