Compare commits

...

4 Commits

3 changed files with 16 additions and 18 deletions

README.md

@@ -197,10 +197,10 @@ clip = CLIP(
     dim_image = 512,
     dim_latent = 512,
     num_text_tokens = 49408,
-    text_enc_depth = 1,
+    text_enc_depth = 6,
     text_seq_len = 256,
     text_heads = 8,
-    visual_enc_depth = 1,
+    visual_enc_depth = 6,
     visual_image_size = 256,
     visual_patch_size = 32,
     visual_heads = 8
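The jump from depth 1 to 6 only grows the text and visual encoders; the training interface is unchanged. As a hedged sketch (the `return_loss = True` contrastive-loss call follows this library's README from memory and is an assumption, not part of this diff):

```python
# mock batch sized to the config above: 49408 text tokens, seq len 256,
# 256px images; assumes `clip` was built as above and moved to GPU
import torch

text   = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

loss = clip(text, images, return_loss = True)  # contrastive loss (assumed API)
loss.backward()
```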
@@ -209,14 +209,15 @@ clip = CLIP(
 # 2 unets for the decoder (a la cascading DDPM)

 unet1 = Unet(
-    dim = 16,
+    dim = 32,
     image_embed_dim = 512,
+    cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8)
 ).cuda()

 unet2 = Unet(
-    dim = 16,
+    dim = 32,
     image_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
@@ -228,8 +229,8 @@ unet2 = Unet(

 decoder = Decoder(
     clip = clip,
     unet = (unet1, unet2),    # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
-    image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
-    timesteps = 100,
+    image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
+    timesteps = 1000,
     cond_drop_prob = 0.2
 ).cuda()
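Taken together, the retuned hunks above (wider unets, text conditioning via `cond_dim` on both stages, 1000 timesteps) are still trained one stage at a time. A minimal sketch, assuming `decoder(...)` returns the DDPM loss for the selected unet as in this version's README; the mock batch is illustrative:

```python
import torch

mock_images = torch.randn(4, 3, 512, 512).cuda()  # stand-in for real training images

# image_sizes = (256, 512): unet 1 learns 256px generation, unet 2 the
# 256 -> 512 upsampling; each stage of the cascade is trained separately
for unet_number in (1, 2):
    loss = decoder(mock_images, unet_number = unet_number)
    loss.backward()
```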
@@ -256,7 +257,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
 images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
 ```

-Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
+Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))

 ```python
 from dalle2_pytorch import DALLE2
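The hunk ends mid code block; for orientation only, here is a hedged completion of that snippet (constructor and call reproduced from memory of this README, so the exact arguments are assumptions, not part of this diff):

```python
dalle2 = DALLE2(
    prior = diffusion_prior,  # assumed: the DiffusionPrior trained earlier in the README
    decoder = decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2.           # assumed: classifier-free guidance scale
)
```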
@@ -408,15 +409,10 @@ Offer training wrappers
 - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
 - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
-- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
-- [ ] make unet more configurable
-- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
 - [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
-- [ ] become an expert with unets, port learnings over to https://github.com/lucidrains/x-unet
-- [ ] train on a toy task, offer in colab
-- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
-- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] train on a toy task, offer in colab

 ## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -1214,7 +1214,7 @@ class Decoder(nn.Module):
         return img

-    def forward(self, image, text = None, unet_number = None):
+    def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         assert 1 <= unet_number <= len(self.unets)
@@ -1230,8 +1230,10 @@ class Decoder(nn.Module):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

-        image_embed = self.get_image_embed(image)
-        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        if not exists(image_embed):
+            image_embed = self.get_image_embed(image)
+
+        text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None

         lowres_cond_img = image if index > 0 else None
         ddpm_image = resize_image_to(image, target_image_size)
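The effect of this hunk: callers may now hand `forward` precomputed CLIP outputs, in which case the internal `get_image_embed` / `get_text_encodings` calls above are skipped. A minimal sketch using the decoder's own helpers, which this diff shows exist; the mock tensors and the `no_grad` precomputation pass are illustrative:

```python
import torch

mock_images = torch.randn(4, 3, 512, 512).cuda()

# precompute once (e.g. in a dataset pass) instead of re-running CLIP
# inside every training step
with torch.no_grad():
    image_embed = decoder.get_image_embed(mock_images)

loss = decoder(
    mock_images,
    image_embed = image_embed,  # new pass-through: internal CLIP encode is skipped
    unet_number = 1
)
loss.backward()
```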

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.22',
+  version = '0.0.24',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',