mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c30544b73a | ||
|
|
bdf5e9c009 | ||
|
|
9878be760b |
54
README.md
54
README.md
@@ -348,7 +348,8 @@ decoder = Decoder(
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -445,6 +446,55 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
image_embed_dim = 512, # this needs to be set
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = torch.randn(4, 512).cuda()
|
||||
clip_text_embeds = torch.randn(4, 512).cuda()
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -593,7 +643,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
|
||||
@@ -486,7 +486,10 @@ class DiffusionPrior(nn.Module):
|
||||
self,
|
||||
net,
|
||||
*,
|
||||
clip,
|
||||
clip = None,
|
||||
image_embed_dim = None,
|
||||
image_size = None,
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = "l1",
|
||||
@@ -495,14 +498,18 @@ class DiffusionPrior(nn.Module):
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
|
||||
if exists(clip):
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
else:
|
||||
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
|
||||
self.clip = None
|
||||
|
||||
self.net = net
|
||||
self.image_embed_dim = clip.dim_latent
|
||||
self.channels = clip.image_channels
|
||||
self.image_size = clip.image_size
|
||||
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
@@ -559,6 +566,8 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
assert exists(self.clip)
|
||||
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
@@ -566,6 +575,8 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
assert exists(self.clip)
|
||||
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
@@ -894,6 +905,7 @@ class Unet(nn.Module):
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -944,7 +956,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
@@ -1072,7 +1084,7 @@ class Unet(nn.Module):
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1170,6 +1182,7 @@ class Decoder(nn.Module):
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -1178,6 +1191,8 @@ class Decoder(nn.Module):
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
@@ -1421,6 +1436,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
@@ -1481,6 +1498,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
@@ -1508,7 +1527,9 @@ class DALLE2(nn.Module):
|
||||
assert isinstance(decoder, Decoder)
|
||||
self.prior = prior
|
||||
self.decoder = decoder
|
||||
|
||||
self.prior_num_samples = prior_num_samples
|
||||
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
@@ -1525,7 +1546,9 @@ class DALLE2(nn.Module):
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
Reference in New Issue
Block a user