mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-13 12:04:24 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9878be760b | ||
|
|
7ba6357c05 | ||
|
|
76e063e8b7 |
72
README.md
72
README.md
@@ -348,7 +348,8 @@ decoder = Decoder(
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -376,6 +377,75 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
|
||||
Working example below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8,
|
||||
).cuda()
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
||||
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
@@ -421,25 +421,41 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
image_embed,
|
||||
diffusion_timesteps,
|
||||
*,
|
||||
text_encodings,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
mask = None,
|
||||
cond_drop_prob = 0.2
|
||||
):
|
||||
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
|
||||
if not exists(text_encodings):
|
||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
mask &= cond_prob_mask
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
mask = torch.cat((mask, cond_prob_mask), dim = 1)
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
@@ -455,16 +471,6 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
# mask if it doesn't exist
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
mask &= rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
# attend
|
||||
|
||||
tokens = self.causal_transformer(tokens, mask = mask)
|
||||
@@ -486,6 +492,7 @@ class DiffusionPrior(nn.Module):
|
||||
loss_type = "l1",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -496,7 +503,9 @@ class DiffusionPrior(nn.Module):
|
||||
self.image_embed_dim = clip.dim_latent
|
||||
self.channels = clip.image_channels
|
||||
self.image_size = clip.image_size
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
@@ -561,6 +570,10 @@ class DiffusionPrior(nn.Module):
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
text_embed = l2norm(text_embed)
|
||||
|
||||
if not self.condition_on_text_encodings:
|
||||
return dict(text_embed = text_embed)
|
||||
|
||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
@@ -679,13 +692,41 @@ class DiffusionPrior(nn.Module):
|
||||
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
||||
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
||||
|
||||
def forward(self, text, image, *args, **kwargs):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
def forward(
|
||||
self,
|
||||
text = None,
|
||||
image = None,
|
||||
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
|
||||
image_embed = None,
|
||||
text_encodings = None, # as well as CLIP text encodings
|
||||
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
|
||||
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
|
||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
image_embed = self.get_image_embed(image)
|
||||
text_cond = self.get_text_cond(text)
|
||||
if exists(image):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_cond = self.get_text_cond(text)
|
||||
else:
|
||||
text_cond = dict(
|
||||
text_embed = text_embed,
|
||||
text_encodings = text_encodings,
|
||||
mask = text_mask
|
||||
)
|
||||
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
batch, device = image_embed.shape[0], image_embed.device
|
||||
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
|
||||
# calculate forward loss
|
||||
|
||||
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||
return loss
|
||||
@@ -853,6 +894,7 @@ class Unet(nn.Module):
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -903,7 +945,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
@@ -1031,7 +1073,7 @@ class Unet(nn.Module):
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1129,6 +1171,7 @@ class Decoder(nn.Module):
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -1137,6 +1180,8 @@ class Decoder(nn.Module):
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
@@ -1380,6 +1425,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
@@ -1440,6 +1487,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
@@ -1467,7 +1516,9 @@ class DALLE2(nn.Module):
|
||||
assert isinstance(decoder, Decoder)
|
||||
self.prior = prior
|
||||
self.decoder = decoder
|
||||
|
||||
self.prior_num_samples = prior_num_samples
|
||||
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
@@ -1484,7 +1535,9 @@ class DALLE2(nn.Module):
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
Reference in New Issue
Block a user