mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 02:14:26 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date |
|---|---|---|
| | f19c99ecb0 | |
| | 721a444686 | |
| | 63450b466d | |
| | 20e7eb5a9b | |
| | e2f9615afa | |
| | 0d1c07c803 | |
README.md (23 lines changed)
@@ -47,7 +47,7 @@ clip = CLIP(
     use_all_token_embeds = True,            # whether to use fine-grained contrastive learning (FILIP)
     decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
     extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
-    use_visual_ssl = True,                  # whether to do self supervised learning on iages
+    use_visual_ssl = True,                  # whether to do self supervised learning on images
     visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
     use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
     text_ssl_loss_weight = 0.05,            # weight for text MLM loss
@@ -110,7 +110,8 @@ decoder = Decoder(
     unet = unet,
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
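These Decoder hunks replace the single `cond_drop_prob` with independent image and text dropout probabilities for classifier free guidance. A minimal sketch of the updated call, assuming `unet` and `clip` are constructed as in the surrounding README example (both names come from that context and are not defined here):

```python
from dalle2_pytorch import Decoder

# sketch only: `unet` and `clip` are assumed to be built as earlier in the README
decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,  # drop the image embedding condition 10% of the time
    text_cond_drop_prob = 0.5    # drop the text condition 50% of the time
).cuda()
```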
@@ -229,7 +230,8 @@ decoder = Decoder(
     unet = (unet1, unet2),      # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512),   # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
     timesteps = 1000,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
@@ -348,7 +350,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
@@ -499,9 +502,7 @@ loss.backward()
 
 Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
 
-First you'll need to install <a href="https://github.com/openai/CLIP#usage">the prerequisites</a>
-
-Then to use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
+To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
 
 ```python
 import torch
@@ -560,7 +561,8 @@ decoder = Decoder(
     image_sizes = (128, 256),
     clip = clip,
     timesteps = 100,
-    cond_drop_prob = 0.2,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5,
     condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
 ).cuda()
 
@@ -618,7 +620,7 @@ clip = CLIP(
 # 3 unets for the decoder (a la cascading DDPM)
 
 # first two unets are doing latent diffusion
-# vqgan-vae must be trained before hand
+# vqgan-vae must be trained beforehand
 
 vae1 = VQGanVAE(
     dim = 32,
@@ -671,7 +673,8 @@ decoder = Decoder(
     unet = (unet1, unet2, unet3),     # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
     image_sizes = (256, 512, 1024),   # resolutions, 256 for first unet, 512 for second, 1024 for third
     timesteps = 100,
-    cond_drop_prob = 0.2
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
 ).cuda()
 
 # mock images (get a lot of this)
dalle2_pytorch/dalle2_pytorch.py
@@ -172,17 +172,13 @@ class OpenAIClipAdapter(BaseClipAdapter):
         self,
         name = 'ViT-B/32'
     ):
-        try:
-            import clip
-        except ImportError:
-            print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
-
-        openai_clip, _ = clip.load(name)
+        import clip
+        openai_clip, preprocess = clip.load(name)
         super().__init__(openai_clip)
 
         text_attention_final = self.find_layer('ln_final')
         self.handle = text_attention_final.register_forward_hook(self._hook)
-        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
 
     def find_layer(self, layer):
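Two things change in the adapter: `import clip` is no longer wrapped in a try/except (the dependency is now declared in setup.py via `clip-anytorch`, see below), and the normalization transform is taken from the preprocessing pipeline returned by `clip.load` instead of being hard-coded. A small sketch of what that last transform is, assuming an OpenAI-compatible `clip` package is installed:

```python
import clip

# clip.load returns the model together with its torchvision preprocessing pipeline
model, preprocess = clip.load('ViT-B/32')

# the final transform in that pipeline is the channel-wise Normalize the adapter used to hard-code,
# with mean (0.48145466, 0.4578275, 0.40821073) and std (0.26862954, 0.26130258, 0.27577711)
clip_normalize = preprocess.transforms[-1]
print(clip_normalize)
```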
@@ -688,14 +684,14 @@ class DiffusionPriorNetwork(nn.Module):
 
         # classifier free guidance
 
-        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
+        keep_mask = rearrange(keep_mask, 'b -> b 1')
 
-        mask &= cond_prob_mask
+        mask &= keep_mask
 
         # whether text embedding is masked or not depends on the classifier free guidance conditional masking
 
-        mask = torch.cat((mask, cond_prob_mask), dim = 1)
+        mask = torch.cat((mask, keep_mask), dim = 1)
 
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
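The rename from `cond_prob_mask` to `keep_mask` also flips the probability: the mask is drawn with probability `1 - cond_drop_prob` and is True for rows whose conditioning is kept. A self-contained sketch of that semantics, using a stand-in for the repo's `prob_mask_like` helper (the helper body here is an assumption, not copied from the source):

```python
import torch
from einops import rearrange

def prob_mask_like(shape, prob, device = None):
    # stand-in: boolean mask that is True with probability `prob`, independently per element
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch, cond_drop_prob = 4, 0.2

keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = 'cpu')  # True -> keep conditioning
keep_mask = rearrange(keep_mask, 'b -> b 1')

mask = torch.ones(batch, 1, dtype = torch.bool)  # pretend text attention mask
mask &= keep_mask                                # dropped rows are also masked out of attention
```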
@@ -1178,7 +1174,7 @@ class Unet(nn.Module):
         if cond_scale == 1:
             return logits
 
-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
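With two dropout knobs, the unconditional branch of classifier free guidance is obtained by forcing both probabilities to 1, so neither the image embedding nor the text conditioning survives. The guidance mix itself is unchanged; restated as a tiny helper for clarity (the names are illustrative):

```python
import torch

def guide(logits: torch.Tensor, null_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # cond_scale = 1 returns the conditional prediction unchanged;
    # cond_scale > 1 pushes the output further away from the fully unconditioned prediction
    return null_logits + (logits - null_logits) * cond_scale
```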
@@ -1189,7 +1185,8 @@ class Unet(nn.Module):
         image_embed,
         lowres_cond_img = None,
         text_encodings = None,
-        cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.,
+        text_cond_drop_prob = 0.,
         blur_sigma = None,
         blur_kernel_size = None
     ):
@@ -1208,8 +1205,10 @@ class Unet(nn.Module):
 
         # conditional dropout
 
-        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
-        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
+        image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+        text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+
+        image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
 
         # mask out image embedding depending on condition dropout
         # for classifier free guidance
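The two keep masks are sampled independently per batch element, broadcast to token shape with `rearrange_many`, and then used by the `torch.where` hunks below to swap dropped rows for a learned null embedding. A self-contained sketch of that broadcast-and-replace pattern (the zero tensor stands in for the learned null embedding, and the shapes are made up for illustration):

```python
import torch
from einops_exts import rearrange_many

batch, num_tokens, dim = 2, 4, 8

image_tokens     = torch.randn(batch, num_tokens, dim)
null_image_embed = torch.zeros(num_tokens, dim)   # stand-in for the learned null embedding

image_keep_mask = torch.rand(batch) > 0.1         # keep with probability 1 - image_cond_drop_prob
text_keep_mask  = torch.rand(batch) > 0.5         # keep with probability 1 - text_cond_drop_prob

image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')

# rows whose keep mask is False fall back to the null embedding (classifier free guidance dropout)
image_tokens = torch.where(image_keep_mask, image_tokens, null_image_embed)
```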
@@ -1220,7 +1219,7 @@ class Unet(nn.Module):
         image_tokens = self.image_to_cond(image_embed)
 
         image_tokens = torch.where(
-            cond_prob_mask,
+            image_keep_mask,
             image_tokens,
             self.null_image_embed
         )
@@ -1232,7 +1231,7 @@
         if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)
             text_tokens = torch.where(
-                cond_prob_mask,
+                text_keep_mask,
                 text_tokens,
                 self.null_text_embed[:, :text_tokens.shape[1]]
             )
@@ -1322,7 +1321,8 @@ class Decoder(BaseGaussianDiffusion):
         clip,
         vae = tuple(),
         timesteps = 1000,
-        cond_drop_prob = 0.2,
+        image_cond_drop_prob = 0.1,
+        text_cond_drop_prob = 0.5,
         loss_type = 'l1',
         beta_schedule = 'cosine',
         predict_x_start = False,
@@ -1406,7 +1406,8 @@ class Decoder(BaseGaussianDiffusion):
 
         # classifier free guidance
 
-        self.cond_drop_prob = cond_drop_prob
+        self.image_cond_drop_prob = image_cond_drop_prob
+        self.text_cond_drop_prob = text_cond_drop_prob
 
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
@@ -1488,7 +1489,8 @@ class Decoder(BaseGaussianDiffusion):
             image_embed = image_embed,
             text_encodings = text_encodings,
             lowres_cond_img = lowres_cond_img,
-            cond_drop_prob = self.cond_drop_prob
+            image_cond_drop_prob = self.image_cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
         target = noise if not predict_x_start else x_start
@@ -1636,4 +1638,3 @@ class DALLE2(nn.Module):
             return images[0]
 
         return images
-
setup.py (3 lines changed)
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.71',
+  version = '0.0.74',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
   ],
   install_requires=[
     'click',
+    'clip-anytorch',
    'einops>=0.4',
    'einops-exts>=0.0.3',
    'kornia>=0.5.4',
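Declaring `clip-anytorch` in `install_requires` is what lets the adapter above do a bare `import clip`: it is a pip-installable packaging of the OpenAI CLIP API, so the manual `pip install git+https://github.com/openai/CLIP.git` step the old README linked to is no longer needed. A quick sanity check, assuming the dependency installed correctly and mirrors the OpenAI CLIP interface:

```python
# after installing dalle2-pytorch (which now pulls in clip-anytorch), this should import directly
import clip

print(clip.available_models())  # e.g. ['RN50', ..., 'ViT-B/32', ...]
```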