mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0bed30a84 | ||
|
|
387c5bf774 | ||
|
|
a13d2d89c5 | ||
|
|
44d4b1bba9 | ||
|
|
f12a7589c5 | ||
|
|
b8af2210df | ||
|
|
f4fe6c570d | ||
|
|
645e207441 | ||
|
|
00743b3a0b | ||
|
|
01589aff6a | ||
|
|
7ecfd76cc0 |
10
README.md
10
README.md
@@ -24,6 +24,8 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
|
||||
|
||||
*ongoing at 21k steps*
|
||||
|
||||
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
|
||||
|
||||
## Pre-Trained Models
|
||||
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
|
||||
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
|
||||
@@ -1048,6 +1050,7 @@ This library would not have gotten to this working state without the help of
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
@@ -1091,7 +1094,7 @@ This library would not have gotten to this working state without the help of
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
@@ -1140,8 +1143,9 @@ This library would not have gotten to this working state without the help of
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2204.01697}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -1343,9 +1343,11 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
num_resnet_blocks = 1,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
@@ -1395,11 +1397,16 @@ class Unet(nn.Module):
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
self.image_to_tokens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.to_image_hiddens = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
) if cond_on_image_embeds and add_image_embeds_to_time else None
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
@@ -1431,6 +1438,7 @@ class Unet(nn.Module):
|
||||
# resnet block klass
|
||||
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
@@ -1446,7 +1454,7 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
@@ -1454,7 +1462,7 @@ class Unet(nn.Module):
|
||||
self.downs.append(nn.ModuleList([
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
@@ -1464,14 +1472,14 @@ class Unet(nn.Module):
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
@@ -1556,6 +1564,13 @@ class Unet(nn.Module):
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
t = self.to_time_cond(time_hiddens)
|
||||
|
||||
# image embedding to be summed to time embedding
|
||||
# discovered by @mhh0318 in the paper
|
||||
|
||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||
image_hiddens = self.to_image_hiddens(image_embed)
|
||||
t = t + image_hiddens
|
||||
|
||||
# conditional dropout
|
||||
|
||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||
@@ -1569,7 +1584,7 @@ class Unet(nn.Module):
|
||||
image_tokens = None
|
||||
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_tokens = self.image_to_tokens(image_embed)
|
||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||
|
||||
image_tokens = torch.where(
|
||||
@@ -1628,10 +1643,13 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
@@ -1642,11 +1660,14 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = block1(x, c, t)
|
||||
x = init_block(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = block2(x, c, t)
|
||||
|
||||
for resnet_block in resnet_blocks:
|
||||
x = resnet_block(x, c, t)
|
||||
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
@@ -4,7 +4,7 @@ In order to make loading data simple and efficient, we include some general data
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
|
||||
)
|
||||
```
|
||||
|
||||
### Diffusion Prior: Prior Embedding Dataset
|
||||
When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
|
||||
|
||||
To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.
|
||||
|
||||
If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
|
||||
# grab embeddings from some specified location
|
||||
IMG_URL = "data/img_emb/"
|
||||
META_URL = "data/meta/"
|
||||
|
||||
reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
|
||||
|
||||
# some config for training
|
||||
TRAIN_ARGS = {
|
||||
"world_size": 3,
|
||||
"text_conditioned": True,
|
||||
"start": 0,
|
||||
"num_data_points": 10000,
|
||||
"batch_size": 2,
|
||||
"train_split": 0.5,
|
||||
"eval_split": 0.25,
|
||||
"image_reader": reader,
|
||||
}
|
||||
|
||||
# specifying a rank will handle allocation internally
|
||||
rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
|
||||
rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
|
||||
rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
|
||||
```
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
|
||||
from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch import from_numpy
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
|
||||
class PriorEmbeddingLoader(IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
device: str = "cpu",
|
||||
) -> None:
|
||||
super(PriorEmbeddingLoader).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.batch_size = batch_size
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
self.n = 0
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
self.n += 1
|
||||
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
tokenized_caption = tokenize(
|
||||
caption["caption"].to_list(), truncate=True
|
||||
).to(self.device)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding).to(self.device)
|
||||
text_embedding = from_numpy(text_embedding).to(self.device)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
device: str,
|
||||
img_url: str,
|
||||
meta_url: str = None,
|
||||
txt_url: str = None,
|
||||
):
|
||||
|
||||
assert img_url is not None, "Must supply some image embeddings"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply metadata url if text-conditioning"
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
else:
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
# compute split points
|
||||
if num_data_points > image_reader.count:
|
||||
print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
|
||||
num_data_points = image_reader.count
|
||||
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_stop = int(train_set_size + eval_set_size)
|
||||
|
||||
train_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=0,
|
||||
stop=train_set_size,
|
||||
device=device,
|
||||
)
|
||||
eval_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=train_set_size,
|
||||
stop=eval_stop,
|
||||
device=device,
|
||||
)
|
||||
test_loader = PriorEmbeddingLoader(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
batch_size=batch_size,
|
||||
start=eval_stop,
|
||||
stop=int(num_data_points),
|
||||
device=device,
|
||||
)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
273
dalle2_pytorch/dataloaders/prior_loader.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from math import ceil
|
||||
from clip import tokenize
|
||||
from embedding_reader import EmbeddingReader
|
||||
from torch import from_numpy
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
|
||||
|
||||
class PriorEmbeddingDataset(IterableDataset):
|
||||
"""
|
||||
PriorEmbeddingDataset is a wrapper of EmbeddingReader.
|
||||
|
||||
It enables one to simplify the logic necessary to yield samples from
|
||||
the different EmbeddingReader configurations available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
start: int,
|
||||
stop: int,
|
||||
image_reader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
) -> None:
|
||||
super(PriorEmbeddingDataset).__init__()
|
||||
|
||||
self.text_conditioned = text_conditioned
|
||||
|
||||
if not self.text_conditioned:
|
||||
self.text_reader = text_reader
|
||||
|
||||
self.image_reader = image_reader
|
||||
self.start = start
|
||||
self.stop = stop
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self.stop - self.start
|
||||
|
||||
def __iter__(self):
|
||||
# D.R.Y loader args
|
||||
loader_args = dict(
|
||||
batch_size=self.batch_size,
|
||||
start=self.start,
|
||||
end=self.stop,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# if the data requested is text conditioned, only load images
|
||||
if self.text_conditioned:
|
||||
self.loader = self.image_reader(**loader_args)
|
||||
# otherwise, include text embeddings and bypass metadata
|
||||
else:
|
||||
self.loader = zip(
|
||||
self.image_reader(**loader_args), self.text_reader(**loader_args)
|
||||
)
|
||||
|
||||
# return the data loader in its formatted state
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.get_sample()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
"""
|
||||
if self.text_conditioned:
|
||||
image_embedding, caption = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
|
||||
|
||||
return image_embedding, tokenized_caption
|
||||
|
||||
else:
|
||||
(image_embedding, _), (text_embedding, _) = next(self.loader)
|
||||
|
||||
image_embedding = from_numpy(image_embedding)
|
||||
text_embedding = from_numpy(text_embedding)
|
||||
|
||||
return image_embedding, text_embedding
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def distribute_to_rank(start, stop, rank, world_size):
|
||||
"""
|
||||
Distribute data to each rank given the world size.
|
||||
|
||||
Return:
|
||||
- New start and stop points for this rank.
|
||||
"""
|
||||
num_samples = int(stop - start)
|
||||
|
||||
per_rank = int(ceil((num_samples) / float(world_size)))
|
||||
|
||||
assert (
|
||||
per_rank > 0
|
||||
), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
|
||||
|
||||
rank_start = start + rank * per_rank
|
||||
|
||||
rank_stop = min(rank_start + per_rank, stop)
|
||||
|
||||
new_length = rank_stop - rank_start
|
||||
|
||||
assert (
|
||||
new_length > 0
|
||||
), "Calculated start and stop points result in a length of zero for this rank."
|
||||
|
||||
return rank_start, rank_stop
|
||||
|
||||
|
||||
def get_reader(
|
||||
text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
|
||||
):
|
||||
"""
|
||||
Create an EmbeddingReader object from the specified URLs
|
||||
|
||||
get_reader() will always expect a url to image embeddings.
|
||||
|
||||
If text-conditioned, it will also expect a meta_url for the captions.
|
||||
Otherwise, it will need txt_url for the matching text embeddings.
|
||||
|
||||
Returns an image_reader object if text-conditioned.
|
||||
Otherwise it returns both an image_reader and a text_reader
|
||||
"""
|
||||
|
||||
assert img_url is not None, "Must supply a image url"
|
||||
|
||||
if text_conditioned:
|
||||
assert meta_url is not None, "Must supply meta url if text-conditioned"
|
||||
|
||||
image_reader = EmbeddingReader(
|
||||
embeddings_folder=img_url,
|
||||
file_format="parquet_npy",
|
||||
# will assume the caption column exists and is the only one requested
|
||||
meta_columns=["caption"],
|
||||
metadata_folder=meta_url,
|
||||
)
|
||||
|
||||
return image_reader
|
||||
|
||||
# otherwise we will require text embeddings as well and return two readers
|
||||
assert (
|
||||
txt_url is not None
|
||||
), "Must supply text embedding url if not text-conditioning"
|
||||
|
||||
image_reader = EmbeddingReader(img_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(txt_url, file_format="npy")
|
||||
|
||||
return image_reader, text_reader
|
||||
|
||||
|
||||
def make_splits(
|
||||
text_conditioned: bool,
|
||||
batch_size: int,
|
||||
num_data_points: int,
|
||||
train_split: float,
|
||||
eval_split: float,
|
||||
image_reader: EmbeddingReader,
|
||||
text_reader: EmbeddingReader = None,
|
||||
start=0,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
):
|
||||
"""
|
||||
Split an embedding reader object as needed.
|
||||
|
||||
NOTE: make_splits() will infer the test set size from your train and eval.
|
||||
|
||||
Input:
|
||||
- text_conditioned: whether to prepare text-conditioned training data
|
||||
- batch_size: the batch size for a single gpu
|
||||
- num_data_points: the total number of data points you wish to train on
|
||||
- train_split: the percentage of data you wish to train on
|
||||
- eval_split: the percentage of data you wish to validate on
|
||||
- image_reader: the image_reader you wish to split
|
||||
- text_reader: the text_reader you want to split (if !text_conditioned)
|
||||
- start: the starting point within your dataset
|
||||
- rank: the rank of your worker
|
||||
- world_size: the total world size of your distributed training run
|
||||
|
||||
Returns:
|
||||
- PyTorch Dataloaders that yield tuples of (img, txt) data.
|
||||
"""
|
||||
|
||||
assert start < image_reader.count, "start position cannot exceed reader count."
|
||||
|
||||
# verify that the num_data_points does not exceed the max points
|
||||
if num_data_points > (image_reader.count - start):
|
||||
print(
|
||||
"Specified count is larger than what's available...defaulting to reader's count."
|
||||
)
|
||||
num_data_points = image_reader.count
|
||||
|
||||
# compute split points
|
||||
train_set_size = int(train_split * num_data_points)
|
||||
eval_set_size = int(eval_split * num_data_points)
|
||||
eval_start = train_set_size
|
||||
eval_stop = int(eval_start + eval_set_size)
|
||||
|
||||
assert (
|
||||
train_split + eval_split
|
||||
) < 1.0, "Specified train and eval split is too large to infer a test split."
|
||||
|
||||
# distribute to rank
|
||||
rank_train_start, rank_train_stop = distribute_to_rank(
|
||||
start, train_set_size, rank, world_size
|
||||
)
|
||||
rank_eval_start, rank_eval_stop = distribute_to_rank(
|
||||
train_set_size, eval_stop, rank, world_size
|
||||
)
|
||||
rank_test_start, rank_test_stop = distribute_to_rank(
|
||||
eval_stop, num_data_points, rank, world_size
|
||||
)
|
||||
|
||||
# wrap up splits into a dict
|
||||
train_split_args = dict(
|
||||
start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
|
||||
)
|
||||
eval_split_args = dict(
|
||||
start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
|
||||
)
|
||||
test_split_args = dict(
|
||||
start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
|
||||
)
|
||||
|
||||
if text_conditioned:
|
||||
# add the text-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
else:
|
||||
# add the non-conditioned args to a unified dict
|
||||
reader_args = dict(
|
||||
text_conditioned=text_conditioned,
|
||||
image_reader=image_reader,
|
||||
text_reader=text_reader,
|
||||
)
|
||||
|
||||
train_split_args = dict(**reader_args, **train_split_args)
|
||||
eval_split_args = dict(**reader_args, **eval_split_args)
|
||||
test_split_args = dict(**reader_args, **test_split_args)
|
||||
|
||||
train = PriorEmbeddingDataset(**train_split_args)
|
||||
val = PriorEmbeddingDataset(**eval_split_args)
|
||||
test = PriorEmbeddingDataset(**test_split_args)
|
||||
|
||||
# true batch size is specifed in the PriorEmbeddingDataset
|
||||
train_loader = DataLoader(train, batch_size=None)
|
||||
eval_loader = DataLoader(val, batch_size=None)
|
||||
test_loader = DataLoader(test, batch_size=None)
|
||||
|
||||
return train_loader, eval_loader, test_loader
|
||||
@@ -27,6 +27,9 @@ def default(val, d):
|
||||
def ListOrTuple(inner_type):
|
||||
return Union[List[inner_type], Tuple[inner_type]]
|
||||
|
||||
def SingularOrIterable(inner_type):
|
||||
return Union[inner_type, ListOrTuple(inner_type)]
|
||||
|
||||
# general pydantic classes
|
||||
|
||||
class TrainSplitConfig(BaseModel):
|
||||
@@ -88,7 +91,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
|
||||
class DiffusionPriorConfig(BaseModel):
|
||||
clip: AdapterConfig
|
||||
clip: AdapterConfig = None
|
||||
net: DiffusionPriorNetworkConfig
|
||||
image_embed_dim: int
|
||||
image_size: int
|
||||
@@ -105,9 +108,16 @@ class DiffusionPriorConfig(BaseModel):
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
clip = AdapterConfig(**kwargs.pop('clip')).create()
|
||||
diffusion_prior_network = DiffusionPriorNetworkConfig(**kwargs.pop('net')).create()
|
||||
return DiffusionPrior(net = diffusion_prior_network, clip=clip, **kwargs)
|
||||
|
||||
has_clip = exists(kwargs.pop('clip'))
|
||||
kwargs.pop('net')
|
||||
|
||||
clip = None
|
||||
if has_clip:
|
||||
clip = self.clip.create()
|
||||
|
||||
diffusion_prior_network = self.net.create()
|
||||
return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
|
||||
|
||||
class DiffusionPriorTrainConfig(BaseModel):
|
||||
epochs: int = 1
|
||||
@@ -215,16 +225,16 @@ class DecoderDataConfig(BaseModel):
|
||||
|
||||
class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: float = 1e-4
|
||||
wd: float = 0.01
|
||||
max_grad_norm: float = 0.5
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
device: str = 'cuda:0'
|
||||
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
|
||||
validation_samples: int = None # Same as above but for validation.
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.5.4',
|
||||
version = '0.6.0',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -347,7 +347,7 @@ def train(
|
||||
# Compute evaluation metrics
|
||||
if exists(evaluate_config):
|
||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
|
||||
tracker.log(evaluation, step=step, verbose=True)
|
||||
|
||||
# Generate sample images
|
||||
|
||||
@@ -7,15 +7,13 @@ import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits
|
||||
from dalle2_pytorch.dataloaders import make_splits, get_reader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
# constants
|
||||
@@ -31,7 +29,7 @@ def exists(val):
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
||||
total_samples = 0.
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
image_embeddings = image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||
diffusion_prior.eval()
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
test_image_embeddings = test_image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
@@ -240,7 +242,7 @@ def train(
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
@@ -249,16 +251,16 @@ def train(
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
else:
|
||||
clip_adapter = None
|
||||
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
@@ -296,28 +298,46 @@ def train(
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
|
||||
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
|
||||
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
loader_args = dict(**loader_args, meta_url=meta_url)
|
||||
reader_args = dict(**reader_args, meta_url=meta_url)
|
||||
img_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader
|
||||
)
|
||||
else:
|
||||
loader_args = dict(**loader_args, txt_url=text_embed_url)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(**loader_args)
|
||||
reader_args = dict(**reader_args, txt_url=text_embed_url)
|
||||
img_reader, txt_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader,
|
||||
text_reader=txt_reader
|
||||
)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
|
||||
image = image.to(device)
|
||||
text = text.to(device)
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
@@ -350,9 +370,9 @@ def train(
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
Reference in New Issue
Block a user