Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-15 11:14:26 +01:00

Compare commits (4 commits)
Commits:

- f19c99ecb0
- 721a444686
- 63450b466d
- 20e7eb5a9b
README.md (19 lines changed)
```diff
@@ -47,7 +47,7 @@ clip = CLIP(
     use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
     decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
     extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
-    use_visual_ssl = True, # whether to do self supervised learning on iages
+    use_visual_ssl = True, # whether to do self supervised learning on images
     visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
     use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
     text_ssl_loss_weight = 0.05, # weight for text MLM loss
```
```diff
@@ -110,7 +110,8 @@ decoder = Decoder(
     unet = unet,
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
```
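The change above (repeated for every `Decoder` example in the README) splits the single `cond_drop_prob` into `image_cond_drop_prob` and `text_cond_drop_prob`, so the image-embedding condition and the text condition are dropped independently during training for classifier-free guidance. A minimal usage sketch of the trained model, assuming the `DALLE2` wrapper and a trained `diffusion_prior` and `decoder` as shown elsewhere in this README:

```python
# a sketch, not a full training script: prior and decoder are assumed trained
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)

# cond_scale > 1 strengthens the conditioning that was learned
# under the two dropout probabilities above
images = dalle2(['cute puppy chasing after a squirrel'], cond_scale = 2.)
```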
```diff
@@ -229,7 +230,8 @@ decoder = Decoder(
     unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
     timesteps = 1000,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
```
```diff
@@ -348,7 +350,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
```
```diff
@@ -558,7 +561,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
```
```diff
@@ -616,7 +620,7 @@ clip = CLIP(
 # 3 unets for the decoder (a la cascading DDPM)
 
 # first two unets are doing latent diffusion
-# vqgan-vae must be trained before hand
+# vqgan-vae must be trained beforehand
 
 vae1 = VQGanVAE(
     dim = 32,
```
```diff
@@ -669,7 +673,8 @@ decoder = Decoder(
     unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
```
dalle2_pytorch/dalle2_pytorch.py

```diff
@@ -173,12 +173,12 @@ class OpenAIClipAdapter(BaseClipAdapter):
         name = 'ViT-B/32'
     ):
         import clip
-        openai_clip, _ = clip.load(name)
+        openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
 
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
-        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
 
     def find_layer(self, layer):
```
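This adapter change stops hard-coding CLIP's normalization constants and instead reuses the `Normalize` transform that `clip.load` already returns as the last step of its preprocessing pipeline. A small sketch to illustrate the equivalence, assuming OpenAI's `clip` package is installed:

```python
import clip
from torchvision import transforms as T

# clip.load returns the model plus a torchvision Compose; its final
# transform is the Normalize carrying CLIP's mean/std
openai_clip, preprocess = clip.load('ViT-B/32')
clip_normalize = preprocess.transforms[-1]

assert isinstance(clip_normalize, T.Normalize)
print(clip_normalize.mean)  # (0.48145466, 0.4578275, 0.40821073)
print(clip_normalize.std)   # (0.26862954, 0.26130258, 0.27577711)
```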
```diff
@@ -1174,7 +1174,7 @@ class Unet(nn.Module):
         if cond_scale == 1:
             return logits
 
-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
```
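The unchanged return line here is the classifier-free guidance combination: the unet is run once with conditioning (producing `logits`) and once with both conditions fully dropped (producing `null_logits`), and the two are extrapolated. A standalone sketch of that arithmetic:

```python
import torch

def guide(cond_logits: torch.Tensor, null_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # cond_scale = 1 returns the conditional prediction unchanged;
    # cond_scale > 1 pushes the prediction away from the unconditional one
    return null_logits + (cond_logits - null_logits) * cond_scale

cond, null = torch.randn(2, 4), torch.randn(2, 4)
assert torch.allclose(guide(cond, null, 1.), cond)
```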
```diff
@@ -1185,7 +1185,8 @@ class Unet(nn.Module):
         image_embed,
         lowres_cond_img = None,
         text_encodings = None,
-        cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.,
+        text_cond_drop_prob = 0.,
         blur_sigma = None,
         blur_kernel_size = None
     ):
```
```diff
@@ -1204,8 +1205,10 @@ class Unet(nn.Module):
 
         # conditional dropout
 
-        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1 1')
+        image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+        text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+
+        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
```
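For reference, `prob_mask_like` draws a per-sample boolean keep mask; the two probabilities now produce independent masks for the image and text pathways. A minimal sketch of the idea (the helper's exact implementation in the repo may differ, and `rearrange_many` from `einops_exts` is replaced here by two plain `rearrange` calls):

```python
import torch
from einops import rearrange

def prob_mask_like(shape, prob, device):
    # True with probability `prob` per element; note the call sites pass
    # the keep-probability, i.e. 1 - drop_prob
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch_size, device = 4, 'cpu'
image_keep_mask = prob_mask_like((batch_size,), 1 - 0.1, device)  # keep image cond ~90% of the time
text_keep_mask  = prob_mask_like((batch_size,), 1 - 0.5, device)  # keep text cond ~50% of the time

# reshape to (b, 1, 1) so each mask broadcasts over token and feature dimensions
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
text_keep_mask  = rearrange(text_keep_mask, 'b -> b 1 1')
```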
```diff
@@ -1216,7 +1219,7 @@ class Unet(nn.Module):
             image_tokens = self.image_to_cond(image_embed)
 
             image_tokens = torch.where(
-                keep_mask,
+                image_keep_mask,
                 image_tokens,
                 self.null_image_embed
             )
```
```diff
@@ -1228,7 +1231,7 @@ class Unet(nn.Module):
         if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
-                keep_mask,
+                text_keep_mask,
                 text_tokens,
                 self.null_text_embed[:, :text_tokens.shape[1]]
             )
```
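Both `torch.where` blocks rely on the same broadcasting trick: a `(b, 1, 1)` boolean mask selects, per sample, either the real conditioning tokens or a learned null embedding. A self-contained sketch with stand-in tensors:

```python
import torch

b, n, d = 4, 77, 512
text_tokens = torch.randn(b, n, d)
null_text_embed = torch.zeros(1, n, d)      # stand-in for the learned null embedding
text_keep_mask = torch.rand(b, 1, 1) > 0.5  # True = keep this sample's conditioning

# samples whose mask is False get the whole token sequence replaced by the null embedding
text_tokens = torch.where(text_keep_mask, text_tokens, null_text_embed)
```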
```diff
@@ -1318,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
         clip,
         vae = tuple(),
         timesteps = 1000,
-        cond_drop_prob = 0.2,
+        image_cond_drop_prob = 0.1,
+        text_cond_drop_prob = 0.5,
         loss_type = 'l1',
         beta_schedule = 'cosine',
         predict_x_start = False,
```
```diff
@@ -1402,7 +1406,8 @@ class Decoder(BaseGaussianDiffusion):
 
         # classifier free guidance
 
-        self.cond_drop_prob = cond_drop_prob
+        self.image_cond_drop_prob = image_cond_drop_prob
+        self.text_cond_drop_prob = text_cond_drop_prob
 
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
```
```diff
@@ -1484,7 +1489,8 @@ class Decoder(BaseGaussianDiffusion):
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
-            cond_drop_prob = self.cond_drop_prob
+            image_cond_drop_prob = self.image_cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
         target = noise if not predict_x_start else x_start
```
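The unchanged target line at the end of this hunk selects between the two standard diffusion regression targets. A brief sketch, assuming the usual DDPM forward process `x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps`:

```python
import torch

x_start = torch.randn(4, 3, 64, 64)  # clean images
noise = torch.randn_like(x_start)    # epsilon
alpha_cumprod_t = torch.tensor(0.9)  # stand-in for the schedule value at step t

# noised input fed to the unet
x_noisy = alpha_cumprod_t.sqrt() * x_start + (1 - alpha_cumprod_t).sqrt() * noise

predict_x_start = False
# epsilon-prediction regresses the added noise; x0-prediction regresses the clean image
target = noise if not predict_x_start else x_start
```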