Compare commits

...

24 Commits

Author SHA1 Message Date
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
b693e0be03 default number of resnet blocks per layer in unet to 2 (in imagen it was 3 for base 64x64) 2022-05-30 10:06:48 -07:00
Phil Wang
a0bed30a84 additional conditioning on image embedding by summing to time embeddings (for FiLM like conditioning in subsequent layers), from passage found in paper by @mhh0318 2022-05-30 09:26:51 -07:00
zion
387c5bf774 quick patch for new prior loader (#123) 2022-05-29 16:25:53 -07:00
Phil Wang
a13d2d89c5 0.5.7 2022-05-29 07:40:25 -07:00
zion
44d4b1bba9 overhaul prior dataloader (#122)
add readme for loader
2022-05-29 07:39:59 -07:00
Phil Wang
f12a7589c5 commit to trying out grid attention 2022-05-26 12:56:10 -07:00
Phil Wang
b8af2210df make sure diffusion prior can be instantiated from pydantic class without clip 2022-05-26 08:47:30 -07:00
Phil Wang
f4fe6c570d allow for full customization of number of resnet blocks per down or upsampling layers in unet, as in imagen 2022-05-26 08:33:31 -07:00
Phil Wang
645e207441 credit assignment 2022-05-26 08:16:03 -07:00
Phil Wang
00743b3a0b update 2022-05-26 08:12:25 -07:00
Phil Wang
01589aff6a cite maxvit properly 2022-05-26 07:12:25 -07:00
Phil Wang
7ecfd76cc0 fix evaluation config splat in training decoder script 2022-05-26 07:11:31 -07:00
Phil Wang
6161b61c55 0.5.4 2022-05-25 09:32:17 -07:00
zion
1ed0f9d80b use deterministic optimizer params (#116) 2022-05-25 09:31:43 -07:00
Phil Wang
f326a95e26 0.5.3 2022-05-25 09:07:28 -07:00
zion
d7a0a2ce4b add more support for configuring prior (#113) 2022-05-25 09:06:50 -07:00
Phil Wang
f23fab7ef7 switch over to scale shift conditioning, as it seems like Imagen and Glide used it and it may be important 2022-05-24 21:46:12 -07:00
Phil Wang
857b9fbf1e allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem 2022-05-24 21:42:32 -07:00
Phil Wang
8864fd0aa7 bring in the dynamic thresholding technique from the Imagen paper, which purportedly improves classifier free guidance for the cascading ddpm 2022-05-24 18:15:14 -07:00
Phil Wang
72bf159331 update 2022-05-24 08:25:40 -07:00
Phil Wang
e5e47cfecb link to aidan's test run 2022-05-23 12:41:46 -07:00
Phil Wang
fa533962bd just use an assert to make sure clip image channels is never different than the channels of the diffusion prior and decoder, if clip is given 2022-05-22 22:43:14 -07:00
Phil Wang
276abf337b fix and cleanup image size determination logic in decoder 2022-05-22 22:28:45 -07:00
15 changed files with 673 additions and 286 deletions

View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a> Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place. As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
## Status ## Status
@@ -24,9 +24,11 @@ There was enough interest for a <a href="https://github.com/lucidrains/dalle2-ja
*ongoing at 21k steps* *ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
## Pre-Trained Models ## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>. - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder 🚧 - Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- DALL-E 2 🚧 - DALL-E 2 🚧
## Install ## Install
@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management - <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs - <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
... and many others. Thank you! 🙏 ... and many others. Thank you! 🙏
@@ -1091,7 +1094,7 @@ This library would not have gotten to this working state without the help of
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
@@ -1140,8 +1143,9 @@ This library would not have gotten to this working state without the help of
```bibtex ```bibtex
@inproceedings{Tu2022MaxViTMV, @inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer}, title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li}, author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022} year = {2022},
url = {https://arxiv.org/abs/2204.01697}
} }
``` ```
@@ -1195,4 +1199,12 @@ This library would not have gotten to this working state without the help of
} }
``` ```
```bibtex
@misc{Saharia2022,
title = {Imagen: unprecedented photorealism × deep level of language understanding},
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -0,0 +1,70 @@
{
"prior": {
"clip": {
"make": "x-clip",
"model": "ViT-L/14",
"base_model_kwargs": {
"dim_text": 768,
"dim_image": 768,
"dim_latent": 768
}
},
"net": {
"dim": 768,
"depth": 12,
"num_timesteps": 1000,
"num_time_embeds": 1,
"num_image_embeds": 1,
"num_text_embeds": 1,
"dim_head": 64,
"heads": 12,
"ff_mult": 4,
"norm_out": true,
"attn_dropout": 0.0,
"ff_dropout": 0.0,
"final_proj": true,
"normformer": true,
"rotary_emb": true
},
"image_embed_dim": 768,
"image_size": 224,
"image_channels": 3,
"timesteps": 1000,
"cond_drop_prob": 0.1,
"loss_type": "l2",
"predict_x_start": true,
"beta_schedule": "cosine",
"condition_on_text_encodings": true
},
"data": {
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
"batch_size": 256,
"splits": {
"train": 0.9,
"val": 1e-7,
"test": 0.0999999
}
},
"train": {
"epochs": 1,
"lr": 1.1e-4,
"wd": 6.02e-2,
"max_grad_norm": 0.5,
"use_ema": true,
"amp": false,
"save_every": 10000
},
"load": {
"source": null,
"resume": false
},
"tracker": {
"tracker_type": "wandb",
"data_path": "./prior_checkpoints",
"wandb_entity": "laion",
"wandb_project": "diffusion-prior",
"verbose": true
}
}

View File

@@ -1,3 +1,4 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -890,6 +890,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
) )
if exists(clip): if exists(clip):
assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP): if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides) clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa): elif isinstance(clip, CoCa):
@@ -1105,13 +1107,20 @@ class Block(nn.Module):
groups = 8 groups = 8
): ):
super().__init__() super().__init__()
self.block = nn.Sequential( self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
nn.Conv2d(dim, dim_out, 3, padding = 1), self.norm = nn.GroupNorm(groups, dim_out)
nn.GroupNorm(groups, dim_out), self.act = nn.SiLU()
nn.SiLU()
) def forward(self, x, scale_shift = None):
def forward(self, x): x = self.project(x)
return self.block(x) x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__( def __init__(
@@ -1130,7 +1139,7 @@ class ResnetBlock(nn.Module):
if exists(time_cond_dim): if exists(time_cond_dim):
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
nn.SiLU(), nn.SiLU(),
nn.Linear(time_cond_dim, dim_out) nn.Linear(time_cond_dim, dim_out * 2)
) )
self.cross_attn = None self.cross_attn = None
@@ -1150,11 +1159,14 @@ class ResnetBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None): def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
scale_shift = None
if exists(self.time_mlp) and exists(time_emb): if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb) time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
if exists(self.cross_attn): if exists(self.cross_attn):
assert exists(cond) assert exists(cond)
@@ -1331,9 +1343,11 @@ class Unet(nn.Module):
cond_on_text_encodings = False, cond_on_text_encodings = False,
max_text_len = 256, max_text_len = 256,
cond_on_image_embeds = False, cond_on_image_embeds = False,
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
init_dim = None, init_dim = None,
init_conv_kernel_size = 7, init_conv_kernel_size = 7,
resnet_groups = 8, resnet_groups = 8,
num_resnet_blocks = 2,
init_cross_embed_kernel_sizes = (3, 7, 15), init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False, cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4), cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1383,11 +1397,16 @@ class Unet(nn.Module):
nn.Linear(time_cond_dim, time_cond_dim) nn.Linear(time_cond_dim, time_cond_dim)
) )
self.image_to_cond = nn.Sequential( self.image_to_tokens = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens), nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens) Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity() ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.to_image_hiddens = nn.Sequential(
nn.Linear(image_embed_dim, time_cond_dim),
nn.GELU()
) if cond_on_image_embeds and add_image_embeds_to_time else None
self.norm_cond = nn.LayerNorm(cond_dim) self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim) self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1419,6 +1438,7 @@ class Unet(nn.Module):
# resnet block klass # resnet block klass
resnet_groups = cast_tuple(resnet_groups, len(in_out)) resnet_groups = cast_tuple(resnet_groups, len(in_out))
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
assert len(resnet_groups) == len(in_out) assert len(resnet_groups) == len(in_out)
@@ -1434,7 +1454,7 @@ class Unet(nn.Module):
self.ups = nn.ModuleList([]) self.ups = nn.ModuleList([])
num_resolutions = len(in_out) num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)): for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
is_first = ind == 0 is_first = ind == 0
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None layer_cond_dim = cond_dim if not is_first else None
@@ -1442,7 +1462,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups), ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(), Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
downsample_klass(dim_out) if not is_last else nn.Identity() downsample_klass(dim_out) if not is_last else nn.Identity()
])) ]))
@@ -1452,14 +1472,14 @@ class Unet(nn.Module):
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))): for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
is_last = ind >= (num_resolutions - 2) is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(), Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
Upsample(dim_in) Upsample(dim_in)
])) ]))
@@ -1544,6 +1564,13 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens) time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens) t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout # conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1557,7 +1584,7 @@ class Unet(nn.Module):
image_tokens = None image_tokens = None
if self.cond_on_image_embeds: if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed) image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where( image_tokens = torch.where(
@@ -1616,10 +1643,13 @@ class Unet(nn.Module):
hiddens = [] hiddens = []
for block1, sparse_attn, block2, downsample in self.downs: for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
x = block1(x, c, t) x = init_block(x, c, t)
x = sparse_attn(x) x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
hiddens.append(x) hiddens.append(x)
x = downsample(x) x = downsample(x)
@@ -1630,11 +1660,14 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c, t) x = self.mid_block2(x, mid_c, t)
for block1, sparse_attn, block2, upsample in self.ups: for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1) x = torch.cat((x, hiddens.pop()), dim=1)
x = block1(x, c, t) x = init_block(x, c, t)
x = sparse_attn(x) x = sparse_attn(x)
x = block2(x, c, t)
for resnet_block in resnet_blocks:
x = resnet_block(x, c, t)
x = upsample(x) x = upsample(x)
return self.final_conv(x) return self.final_conv(x)
@@ -1702,6 +1735,8 @@ class Decoder(BaseGaussianDiffusion):
vb_loss_weight = 0.001, vb_loss_weight = 0.001,
unconditional = False, unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -1710,12 +1745,19 @@ class Decoder(BaseGaussianDiffusion):
) )
self.unconditional = unconditional self.unconditional = unconditional
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)' # text conditioning
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
self.condition_on_text_encodings = condition_on_text_encodings
# clip
self.clip = None self.clip = None
if exists(clip): if exists(clip):
assert not unconditional, 'clip must not be given if doing unconditional image training'
assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
if isinstance(clip, CLIP): if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides) clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa): elif isinstance(clip, CoCa):
@@ -1725,13 +1767,20 @@ class Decoder(BaseGaussianDiffusion):
assert isinstance(clip, BaseClipAdapter) assert isinstance(clip, BaseClipAdapter)
self.clip = clip self.clip = clip
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
else:
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
self.channels = channels
self.condition_on_text_encodings = condition_on_text_encodings # determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
else:
raise Error('either image_size, image_sizes, or clip must be given to decoder')
# channels
self.channels = channels
# automatically take care of ensuring that first unet is unconditional # automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet # while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1773,7 +1822,7 @@ class Decoder(BaseGaussianDiffusion):
# unet image sizes # unet image sizes
image_sizes = default(image_sizes, (self.clip_image_size,)) image_sizes = default(image_sizes, (image_size,))
image_sizes = tuple(sorted(set(image_sizes))) image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}' assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1810,7 +1859,13 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start self.clip_x_start = clip_x_start
# dynamic thresholding settings, if clipping denoised during sampling
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions # normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
@@ -1851,7 +1906,21 @@ class Decoder(BaseGaussianDiffusion):
x_recon = self.predict_start_from_noise(x, t = t, noise = pred) x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised: if clip_denoised:
x_recon.clamp_(-1., 1.) # s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)

View File

@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
) )
``` ```
### Diffusion Prior: Prior Embedding Dataset
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
Usage:
```python
from dalle2_pytorch.dataloaders import get_reader, make_splits
# grab embeddings from some specified location
IMG_URL = "data/img_emb/"
META_URL = "data/meta/"
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
# some config for training
TRAIN_ARGS = {
"world_size": 3,
"text_conditioned": True,
"start": 0,
"num_data_points": 10000,
"batch_size": 2,
"train_split": 0.5,
"eval_split": 0.25,
"image_reader": reader,
}
# specifying a rank will handle allocation internally
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
```

View File

@@ -1,2 +1,2 @@
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset

View File

@@ -1,180 +0,0 @@
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader
class PriorEmbeddingLoader(IterableDataset):
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
device: str = "cpu",
) -> None:
super(PriorEmbeddingLoader).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.batch_size = batch_size
self.start = start
self.stop = stop
self.device = device
def __iter__(self):
self.n = 0
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
self.n += 1
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
tokenized_caption = tokenize(
caption["caption"].to_list(), truncate=True
).to(self.device)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding).to(self.device)
text_embedding = from_numpy(text_embedding).to(self.device)
return image_embedding, text_embedding
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
device: str,
img_url: str,
meta_url: str = None,
txt_url: str = None,
):
assert img_url is not None, "Must supply some image embeddings"
if text_conditioned:
assert meta_url is not None, "Must supply metadata url if text-conditioning"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
meta_columns=["caption"],
metadata_folder=meta_url,
)
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
else:
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
# compute split points
if num_data_points > image_reader.count:
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
num_data_points = image_reader.count
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_stop = int(train_set_size + eval_set_size)
train_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=0,
stop=train_set_size,
device=device,
)
eval_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=train_set_size,
stop=eval_stop,
device=device,
)
test_loader = PriorEmbeddingLoader(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
batch_size=batch_size,
start=eval_stop,
stop=int(num_data_points),
device=device,
)
return train_loader, eval_loader, test_loader

View File

@@ -0,0 +1,273 @@
from math import ceil
from clip import tokenize
from embedding_reader import EmbeddingReader
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader
class PriorEmbeddingDataset(IterableDataset):
"""
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
It enables one to simplify the logic necessary to yield samples from
the different EmbeddingReader configurations available.
"""
def __init__(
self,
text_conditioned: bool,
batch_size: int,
start: int,
stop: int,
image_reader,
text_reader: EmbeddingReader = None,
) -> None:
super(PriorEmbeddingDataset).__init__()
self.text_conditioned = text_conditioned
if not self.text_conditioned:
self.text_reader = text_reader
self.image_reader = image_reader
self.start = start
self.stop = stop
self.batch_size = batch_size
def __len__(self):
return self.stop - self.start
def __iter__(self):
# D.R.Y loader args
loader_args = dict(
batch_size=self.batch_size,
start=self.start,
end=self.stop,
show_progress=False,
)
# if the data requested is text conditioned, only load images
if self.text_conditioned:
self.loader = self.image_reader(**loader_args)
# otherwise, include text embeddings and bypass metadata
else:
self.loader = zip(
self.image_reader(**loader_args), self.text_reader(**loader_args)
)
# return the data loader in its formatted state
return self
def __next__(self):
try:
return self.get_sample()
except StopIteration:
raise StopIteration
def __str__(self):
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
def get_sample(self):
"""
pre-proocess data from either reader into a common format
"""
if self.text_conditioned:
image_embedding, caption = next(self.loader)
image_embedding = from_numpy(image_embedding)
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
return image_embedding, tokenized_caption
else:
(image_embedding, _), (text_embedding, _) = next(self.loader)
image_embedding = from_numpy(image_embedding)
text_embedding = from_numpy(text_embedding)
return image_embedding, text_embedding
# helper functions
def distribute_to_rank(start, stop, rank, world_size):
"""
Distribute data to each rank given the world size.
Return:
- New start and stop points for this rank.
"""
num_samples = int(stop - start)
per_rank = int(ceil((num_samples) / float(world_size)))
assert (
per_rank > 0
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
rank_start = start + rank * per_rank
rank_stop = min(rank_start + per_rank, stop)
new_length = rank_stop - rank_start
assert (
new_length > 0
), "Calculated start and stop points result in a length of zero for this rank."
return rank_start, rank_stop
def get_reader(
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
"""
Create an EmbeddingReader object from the specified URLs
get_reader() will always expect a url to image embeddings.
If text-conditioned, it will also expect a meta_url for the captions.
Otherwise, it will need txt_url for the matching text embeddings.
Returns an image_reader object if text-conditioned.
Otherwise it returns both an image_reader and a text_reader
"""
assert img_url is not None, "Must supply a image url"
if text_conditioned:
assert meta_url is not None, "Must supply meta url if text-conditioned"
image_reader = EmbeddingReader(
embeddings_folder=img_url,
file_format="parquet_npy",
# will assume the caption column exists and is the only one requested
meta_columns=["caption"],
metadata_folder=meta_url,
)
return image_reader
# otherwise we will require text embeddings as well and return two readers
assert (
txt_url is not None
), "Must supply text embedding url if not text-conditioning"
image_reader = EmbeddingReader(img_url, file_format="npy")
text_reader = EmbeddingReader(txt_url, file_format="npy")
return image_reader, text_reader
def make_splits(
text_conditioned: bool,
batch_size: int,
num_data_points: int,
train_split: float,
eval_split: float,
image_reader: EmbeddingReader,
text_reader: EmbeddingReader = None,
start=0,
rank=0,
world_size=1,
):
"""
Split an embedding reader object as needed.
NOTE: make_splits() will infer the test set size from your train and eval.
Input:
- text_conditioned: whether to prepare text-conditioned training data
- batch_size: the batch size for a single gpu
- num_data_points: the total number of data points you wish to train on
- train_split: the percentage of data you wish to train on
- eval_split: the percentage of data you wish to validate on
- image_reader: the image_reader you wish to split
- text_reader: the text_reader you want to split (if !text_conditioned)
- start: the starting point within your dataset
- rank: the rank of your worker
- world_size: the total world size of your distributed training run
Returns:
- PyTorch Dataloaders that yield tuples of (img, txt) data.
"""
assert start < image_reader.count, "start position cannot exceed reader count."
# verify that the num_data_points does not exceed the max points
if num_data_points > (image_reader.count - start):
print(
"Specified count is larger than what's available...defaulting to reader's count."
)
num_data_points = image_reader.count
# compute split points
train_set_size = int(train_split * num_data_points)
eval_set_size = int(eval_split * num_data_points)
eval_start = train_set_size
eval_stop = int(eval_start + eval_set_size)
assert (
train_split + eval_split
) < 1.0, "Specified train and eval split is too large to infer a test split."
# distribute to rank
rank_train_start, rank_train_stop = distribute_to_rank(
start, train_set_size, rank, world_size
)
rank_eval_start, rank_eval_stop = distribute_to_rank(
train_set_size, eval_stop, rank, world_size
)
rank_test_start, rank_test_stop = distribute_to_rank(
eval_stop, num_data_points, rank, world_size
)
# wrap up splits into a dict
train_split_args = dict(
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
)
eval_split_args = dict(
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
)
test_split_args = dict(
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
)
if text_conditioned:
# add the text-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
else:
# add the non-conditioned args to a unified dict
reader_args = dict(
text_conditioned=text_conditioned,
image_reader=image_reader,
text_reader=text_reader,
)
train_split_args = dict(**reader_args, **train_split_args)
eval_split_args = dict(**reader_args, **eval_split_args)
test_split_args = dict(**reader_args, **test_split_args)
train = PriorEmbeddingDataset(**train_split_args)
val = PriorEmbeddingDataset(**eval_split_args)
test = PriorEmbeddingDataset(**test_split_args)
# true batch size is specifed in the PriorEmbeddingDataset
train_loader = DataLoader(train, batch_size=None)
eval_loader = DataLoader(val, batch_size=None)
test_loader = DataLoader(test, batch_size=None)
return train_loader, eval_loader, test_loader

View File

@@ -12,6 +12,7 @@ def get_optimizer(
betas = (0.9, 0.999), betas = (0.9, 0.999),
eps = 1e-8, eps = 1e-8,
filter_by_requires_grad = False, filter_by_requires_grad = False,
group_wd_params = True,
**kwargs **kwargs
): ):
if filter_by_requires_grad: if filter_by_requires_grad:
@@ -20,12 +21,12 @@ def get_optimizer(
if wd == 0: if wd == 0:
return Adam(params, lr = lr, betas = betas, eps = eps) return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params) if group_wd_params:
wd_params, no_wd_params = separate_weight_decayable_params(params) wd_params, no_wd_params = separate_weight_decayable_params(params)
param_groups = [ params = [
{'params': list(wd_params)}, {'params': list(wd_params)},
{'params': list(no_wd_params), 'weight_decay': 0}, {'params': list(no_wd_params), 'weight_decay': 0},
] ]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps) return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -3,7 +3,18 @@ from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from dalle2_pytorch.dalle2_pytorch import Unet, Decoder, DiffusionPrior, DiffusionPriorNetwork from x_clip import CLIP as XCLIP
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
Unet,
Decoder,
DiffusionPrior,
DiffusionPriorNetwork,
XClipAdapter,
)
# helper functions # helper functions
@@ -16,7 +27,47 @@ def default(val, d):
def ListOrTuple(inner_type): def ListOrTuple(inner_type):
return Union[List[inner_type], Tuple[inner_type]] return Union[List[inner_type], Tuple[inner_type]]
# pydantic classes def SingularOrIterable(inner_type):
return Union[inner_type, ListOrTuple(inner_type)]
# general pydantic classes
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
actual_sum = sum([*fields.values()])
if actual_sum != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
# diffusion prior pydantic classes
class AdapterConfig(BaseModel):
make: str = "openai"
model: str = "ViT-L/14"
base_model_kwargs: Dict[str, Any] = None
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
return CoCaAdapter(CoCa(**self.base_model_kwargs))
else:
raise AttributeError("No adapter with that name is available.")
class DiffusionPriorNetworkConfig(BaseModel): class DiffusionPriorNetworkConfig(BaseModel):
dim: int dim: int
@@ -35,8 +86,12 @@ class DiffusionPriorNetworkConfig(BaseModel):
normformer: bool = False normformer: bool = False
rotary_emb: bool = True rotary_emb: bool = True
def create(self):
kwargs = self.dict()
return DiffusionPriorNetwork(**kwargs)
class DiffusionPriorConfig(BaseModel): class DiffusionPriorConfig(BaseModel):
# only clip-less diffusion prior config for now clip: AdapterConfig = None
net: DiffusionPriorNetworkConfig net: DiffusionPriorNetworkConfig
image_embed_dim: int image_embed_dim: int
image_size: int image_size: int
@@ -46,15 +101,59 @@ class DiffusionPriorConfig(BaseModel):
loss_type: str = 'l2' loss_type: str = 'l2'
predict_x_start: bool = True predict_x_start: bool = True
beta_schedule: str = 'cosine' beta_schedule: str = 'cosine'
condition_on_text_encodings: bool = True
def create(self):
kwargs = self.dict()
diffusion_prior_network = DiffusionPriorNetwork(**kwargs.pop('net'))
return DiffusionPrior(net = diffusion_prior_network, **kwargs)
class Config: class Config:
extra = "allow" extra = "allow"
def create(self):
kwargs = self.dict()
has_clip = exists(kwargs.pop('clip'))
kwargs.pop('net')
clip = None
if has_clip:
clip = self.clip.create()
diffusion_prior_network = self.net.create()
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
class DiffusionPriorTrainConfig(BaseModel):
epochs: int = 1
lr: float = 1.1e-4
wd: float = 6.02e-2
max_grad_norm: float = 0.5
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_every: int = 10000 # what steps to save on
class DiffusionPriorDataConfig(BaseModel):
image_url: str # path to embeddings folder
meta_url: str # path to metadata (captions) for images
splits: TrainSplitConfig
batch_size: int = 64
class DiffusionPriorLoadConfig(BaseModel):
source: str = None
resume: bool = False
class TrainDiffusionPriorConfig(BaseModel):
prior: DiffusionPriorConfig
data: DiffusionPriorDataConfig
train: DiffusionPriorTrainConfig
load: DiffusionPriorLoadConfig
tracker: TrackerConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
# decoder pydantic classes
class UnetConfig(BaseModel): class UnetConfig(BaseModel):
dim: int dim: int
dim_mults: ListOrTuple(int) dim_mults: ListOrTuple(int)
@@ -94,17 +193,6 @@ class DecoderConfig(BaseModel):
class Config: class Config:
extra = "allow" extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel): class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings embeddings_url: str # path to .npy files with embeddings
@@ -137,16 +225,16 @@ class DecoderDataConfig(BaseModel):
class DecoderTrainConfig(BaseModel): class DecoderTrainConfig(BaseModel):
epochs: int = 20 epochs: int = 20
lr: float = 1e-4 lr: SingularOrIterable(float) = 1e-4
wd: float = 0.01 wd: SingularOrIterable(float) = 0.01
max_grad_norm: float = 0.5 max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000 save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0' device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite. epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation. validation_samples: int = None # Same as above but for validation.
use_ema: bool = True use_ema: bool = True
ema_beta: float = 0.99 ema_beta: float = 0.999
amp: bool = False amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint save_latest: bool = True # Whether to always save the latest checkpoint
@@ -160,14 +248,6 @@ class DecoderEvaluateConfig(BaseModel):
KID: Dict[str, Any] = None KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel): class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb run_path: str = '' # Used only if source is wandb

View File

@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import numpy as np import numpy as np
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
return arr return arr
def get_pkg_version(): def get_pkg_version():
from pkg_resources import get_distribution return __version__
return get_distribution('dalle2_pytorch').version
# decorators # decorators
@@ -254,6 +255,7 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6, eps = 1e-6,
max_grad_norm = None, max_grad_norm = None,
amp = False, amp = False,
group_wd_params = True,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -279,6 +281,7 @@ class DiffusionPriorTrainer(nn.Module):
lr = lr, lr = lr,
wd = wd, wd = wd,
eps = eps, eps = eps,
group_wd_params = group_wd_params,
**kwargs **kwargs
) )
@@ -297,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
scaler = self.scaler.state_dict(), scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(), optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(), model = self.diffusion_prior.state_dict(),
version = get_pkg_version(), version = __version__,
step = self.step.item(), step = self.step.item(),
**kwargs **kwargs
) )
@@ -313,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
loaded_obj = torch.load(str(path)) loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}') print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict) self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
@@ -410,6 +413,7 @@ class DecoderTrainer(nn.Module):
eps = 1e-8, eps = 1e-8,
max_grad_norm = 0.5, max_grad_norm = 0.5,
amp = False, amp = False,
group_wd_params = True,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -435,6 +439,7 @@ class DecoderTrainer(nn.Module):
lr = unet_lr, lr = unet_lr,
wd = unet_wd, wd = unet_wd,
eps = unet_eps, eps = unet_eps,
group_wd_params = group_wd_params,
**kwargs **kwargs
) )
@@ -459,7 +464,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict( save_obj = dict(
model = self.decoder.state_dict(), model = self.decoder.state_dict(),
version = get_pkg_version(), version = __version__,
step = self.step.item(), step = self.step.item(),
**kwargs **kwargs
) )
@@ -482,7 +487,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path)) loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}') print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict) self.decoder.load_state_dict(loaded_obj['model'], strict = strict)

View File

@@ -0,0 +1 @@
__version__ = '0.6.2'

View File

@@ -1,4 +1,5 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup( setup(
name = 'dalle2-pytorch', name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.4.10', version = __version__,
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',

View File

@@ -347,7 +347,7 @@ def train(
# Compute evaluation metrics # Compute evaluation metrics
if exists(evaluate_config): if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config) evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
tracker.log(evaluation, step=step, verbose=True) tracker.log(evaluation, step=step, verbose=True)
# Generate sample images # Generate sample images

View File

@@ -7,15 +7,13 @@ import torch
import clip import clip
from torch import nn from torch import nn
from dalle2_pytorch.dataloaders import make_splits from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader
from tqdm import tqdm from tqdm import tqdm
# constants # constants
@@ -31,7 +29,7 @@ def exists(val):
# functions # functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"): def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
total_samples = 0. total_samples = 0.
for image_embeddings, text_data in tqdm(dataloader): for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
batches = image_embeddings.shape[0] batches = image_embeddings.shape[0]
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
tracker.log({f'{phase} {loss_type}': avg_loss}) tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned): def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval() diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6) cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader): for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
# we are text conditioned, we produce an embedding from the tokenized text # we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned: if text_conditioned:
@@ -296,15 +298,31 @@ def train(
# Utilize wrapper to abstract away loader logic # Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings") print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points, reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url) reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else: else:
loader_args = dict(**loader_args, txt_url=text_embed_url) reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(**loader_args) train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ### ### Training code ###
@@ -315,9 +333,11 @@ def train(
for _ in range(epochs): for _ in range(epochs):
for image, text in tqdm(train_loader): for image, text in tqdm(train_loader):
diffusion_prior.train() diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image) input_args = dict(image_embed=image)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text) input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
# Use NUM_TEST_EMBEDDINGS samples from the test set each time # Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model # Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0: if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings) report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ### ### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation") eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1 step += 1
trainer.update() trainer.update()