Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-23 22:14:20 +01:00)

Compare commits: 7 commits
| SHA1 |
|---|
| b588286288 |
| b693e0be03 |
| a0bed30a84 |
| 387c5bf774 |
| a13d2d89c5 |
| 44d4b1bba9 |
| f12a7589c5 |
README.md:

```diff
@@ -1094,7 +1094,7 @@ This library would not have gotten to this working state without the help of
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
-- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
+- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
```
````diff
@@ -1144,7 +1144,8 @@ This library would not have gotten to this working state without the help of
 @inproceedings{Tu2022MaxViTMV,
     title   = {MaxViT: Multi-Axis Vision Transformer},
     author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
-    year    = {2022}
+    year    = {2022},
+    url     = {https://arxiv.org/abs/2204.01697}
 }
 ```
 
````
````diff
@@ -1199,7 +1200,7 @@ This library would not have gotten to this working state without the help of
 ```
 
 ```bibtex
-@misc{Saharia2022,https://stability.ai/
+@misc{Saharia2022,
     title   = {Imagen: unprecedented photorealism × deep level of language understanding},
     author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
     year    = {2022}
````
dalle2_pytorch/__init__.py:

```diff
@@ -1,3 +1,4 @@
+from dalle2_pytorch.version import __version__
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
 from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
```
dalle2_pytorch/dalle2_pytorch.py:

```diff
@@ -1343,10 +1343,11 @@ class Unet(nn.Module):
         cond_on_text_encodings = False,
         max_text_len = 256,
         cond_on_image_embeds = False,
+        add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
-        num_resnet_blocks = 1,
+        num_resnet_blocks = 2,
         init_cross_embed_kernel_sizes = (3, 7, 15),
         cross_embed_downsample = False,
         cross_embed_downsample_kernel_sizes = (2, 4),
```
```diff
@@ -1396,11 +1397,16 @@ class Unet(nn.Module):
             nn.Linear(time_cond_dim, time_cond_dim)
         )
 
-        self.image_to_cond = nn.Sequential(
+        self.image_to_tokens = nn.Sequential(
             nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
         ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
 
+        self.to_image_hiddens = nn.Sequential(
+            nn.Linear(image_embed_dim, time_cond_dim),
+            nn.GELU()
+        ) if cond_on_image_embeds and add_image_embeds_to_time else None
+
         self.norm_cond = nn.LayerNorm(cond_dim)
         self.norm_mid_cond = nn.LayerNorm(cond_dim)
 
```
```diff
@@ -1558,6 +1564,13 @@ class Unet(nn.Module):
         time_tokens = self.to_time_tokens(time_hiddens)
         t = self.to_time_cond(time_hiddens)
 
+        # image embedding to be summed to time embedding
+        # discovered by @mhh0318 in the paper
+
+        if exists(image_embed) and exists(self.to_image_hiddens):
+            image_hiddens = self.to_image_hiddens(image_embed)
+            t = t + image_hiddens
+
         # conditional dropout
 
         image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
```
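In plain terms, the addition above projects the CLIP image embedding into the time-conditioning dimension and sums it into `t`. A minimal standalone sketch of that mechanic, with hypothetical dimensions (not taken from the repo):

```python
import torch
from torch import nn

# hypothetical dims for illustration only
image_embed_dim, time_cond_dim = 512, 1024

# mirrors self.to_image_hiddens in the diff above
to_image_hiddens = nn.Sequential(
    nn.Linear(image_embed_dim, time_cond_dim),
    nn.GELU()
)

image_embed = torch.randn(4, image_embed_dim)  # CLIP image embeddings (batch of 4)
t = torch.randn(4, time_cond_dim)              # timestep conditioning vector

# the projected image embedding is simply summed into the time embedding,
# per the phrase quoted from the paper
t = t + to_image_hiddens(image_embed)
print(t.shape)  # torch.Size([4, 1024])
```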
```diff
@@ -1571,7 +1584,7 @@ class Unet(nn.Module):
         image_tokens = None
 
         if self.cond_on_image_embeds:
-            image_tokens = self.image_to_cond(image_embed)
+            image_tokens = self.image_to_tokens(image_embed)
             null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
 
             image_tokens = torch.where(
```
dalle2_pytorch/dataloaders/README.md:

````diff
@@ -39,3 +39,37 @@ dataset = ImageEmbeddingDataset(
 )
 ```
 
+### Diffusion Prior: Prior Embedding Dataset
+When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
+
+To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects for your training run.
+
+If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
+
+Usage:
+```python
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+
+# grab embeddings from some specified location
+IMG_URL = "data/img_emb/"
+META_URL = "data/meta/"
+
+reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
+
+# some config for training
+TRAIN_ARGS = {
+    "world_size": 3,
+    "text_conditioned": True,
+    "start": 0,
+    "num_data_points": 10000,
+    "batch_size": 2,
+    "train_split": 0.5,
+    "eval_split": 0.25,
+    "image_reader": reader,
+}
+
+# specifying a rank will handle allocation internally
+rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
+rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
+rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
+```
````
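The README addition above only demonstrates the text-conditioned path. For completeness, a hedged sketch of the embedding-only setup, using the same `get_reader()`/`make_splits()` API introduced in `prior_loader.py` below; the `data/txt_emb/` path is hypothetical:

```python
from dalle2_pytorch.dataloaders import get_reader, make_splits

# embedding-only mode returns two readers: images plus matching text embeddings
# ("data/txt_emb/" is a hypothetical location for precomputed text embeddings)
img_reader, txt_reader = get_reader(
    text_conditioned=False, img_url="data/img_emb/", txt_url="data/txt_emb/"
)

train_loader, eval_loader, test_loader = make_splits(
    text_conditioned=False,
    batch_size=2,
    num_data_points=10000,
    train_split=0.5,
    eval_split=0.25,
    image_reader=img_reader,
    text_reader=txt_reader,
)
```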
dalle2_pytorch/dataloaders/__init__.py:

```diff
@@ -1,2 +1,2 @@
 from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
-from dalle2_pytorch.dataloaders.embedding_wrapper import make_splits
+from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
```
dalle2_pytorch/dataloaders/embedding_wrapper.py (deleted; the entire 180-line file follows):

```python
from torch.utils.data import IterableDataset
from torch import from_numpy
from clip import tokenize
from embedding_reader import EmbeddingReader


class PriorEmbeddingLoader(IterableDataset):
    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
        device: str = "cpu",
    ) -> None:
        super(PriorEmbeddingLoader).__init__()

        self.text_conditioned = text_conditioned

        if not self.text_conditioned:
            self.text_reader = text_reader

        self.image_reader = image_reader
        self.batch_size = batch_size
        self.start = start
        self.stop = stop
        self.device = device

    def __iter__(self):
        self.n = 0
        loader_args = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )
        if self.text_conditioned:
            self.loader = self.image_reader(**loader_args)
        else:
            self.loader = zip(
                self.image_reader(**loader_args), self.text_reader(**loader_args)
            )
        return self

    def __next__(self):
        try:
            return self.get_sample()
        except StopIteration:
            raise StopIteration

    def get_sample(self):
        """
        pre-process data from either reader into a common format
        """
        self.n += 1

        if self.text_conditioned:
            image_embedding, caption = next(self.loader)

            image_embedding = from_numpy(image_embedding).to(self.device)
            tokenized_caption = tokenize(
                caption["caption"].to_list(), truncate=True
            ).to(self.device)

            return image_embedding, tokenized_caption

        else:
            (image_embedding, _), (text_embedding, _) = next(self.loader)

            image_embedding = from_numpy(image_embedding).to(self.device)
            text_embedding = from_numpy(text_embedding).to(self.device)

            return image_embedding, text_embedding


def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    device: str,
    img_url: str,
    meta_url: str = None,
    txt_url: str = None,
):

    assert img_url is not None, "Must supply some image embeddings"

    if text_conditioned:
        assert meta_url is not None, "Must supply metadata url if text-conditioning"
        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        # compute split points
        if num_data_points > image_reader.count:
            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
            num_data_points = image_reader.count

        train_set_size = int(train_split * num_data_points)
        eval_set_size = int(eval_split * num_data_points)
        eval_stop = int(train_set_size + eval_set_size)

        train_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=0,
            stop=train_set_size,
            device=device,
        )
        eval_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=train_set_size,
            stop=eval_stop,
            device=device,
        )
        test_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            batch_size=batch_size,
            start=eval_stop,
            stop=int(num_data_points),
            device=device,
        )

    else:
        assert (
            txt_url is not None
        ), "Must supply text embedding url if not text-conditioning"

        image_reader = EmbeddingReader(img_url, file_format="npy")
        text_reader = EmbeddingReader(txt_url, file_format="npy")

        # compute split points
        if num_data_points > image_reader.count:
            print("Specified point count is larger than the number of points available...defaulting to max length of reader.")
            num_data_points = image_reader.count

        train_set_size = int(train_split * num_data_points)
        eval_set_size = int(eval_split * num_data_points)
        eval_stop = int(train_set_size + eval_set_size)

        train_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=0,
            stop=train_set_size,
            device=device,
        )
        eval_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=train_set_size,
            stop=eval_stop,
            device=device,
        )
        test_loader = PriorEmbeddingLoader(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
            batch_size=batch_size,
            start=eval_stop,
            stop=int(num_data_points),
            device=device,
        )

    return train_loader, eval_loader, test_loader
```
dalle2_pytorch/dataloaders/prior_loader.py (new file, 273 lines):

```python
from math import ceil
from clip import tokenize
from embedding_reader import EmbeddingReader
from torch import from_numpy
from torch.utils.data import IterableDataset, DataLoader


class PriorEmbeddingDataset(IterableDataset):
    """
    PriorEmbeddingDataset is a wrapper of EmbeddingReader.

    It enables one to simplify the logic necessary to yield samples from
    the different EmbeddingReader configurations available.
    """

    def __init__(
        self,
        text_conditioned: bool,
        batch_size: int,
        start: int,
        stop: int,
        image_reader,
        text_reader: EmbeddingReader = None,
    ) -> None:
        super(PriorEmbeddingDataset).__init__()

        self.text_conditioned = text_conditioned

        if not self.text_conditioned:
            self.text_reader = text_reader

        self.image_reader = image_reader
        self.start = start
        self.stop = stop
        self.batch_size = batch_size

    def __len__(self):
        return self.stop - self.start

    def __iter__(self):
        # D.R.Y loader args
        loader_args = dict(
            batch_size=self.batch_size,
            start=self.start,
            end=self.stop,
            show_progress=False,
        )

        # if the data requested is text conditioned, only load images
        if self.text_conditioned:
            self.loader = self.image_reader(**loader_args)
        # otherwise, include text embeddings and bypass metadata
        else:
            self.loader = zip(
                self.image_reader(**loader_args), self.text_reader(**loader_args)
            )

        # return the data loader in its formatted state
        return self

    def __next__(self):
        try:
            return self.get_sample()
        except StopIteration:
            raise StopIteration

    def __str__(self):
        return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"

    def get_sample(self):
        """
        pre-process data from either reader into a common format
        """
        if self.text_conditioned:
            image_embedding, caption = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)

            return image_embedding, tokenized_caption

        else:
            (image_embedding, _), (text_embedding, _) = next(self.loader)

            image_embedding = from_numpy(image_embedding)
            text_embedding = from_numpy(text_embedding)

            return image_embedding, text_embedding


# helper functions


def distribute_to_rank(start, stop, rank, world_size):
    """
    Distribute data to each rank given the world size.

    Return:
        - New start and stop points for this rank.
    """
    num_samples = int(stop - start)

    per_rank = int(ceil((num_samples) / float(world_size)))

    assert (
        per_rank > 0
    ), f"Number of samples per rank must be larger than 0, (found: {per_rank})"

    rank_start = start + rank * per_rank

    rank_stop = min(rank_start + per_rank, stop)

    new_length = rank_stop - rank_start

    assert (
        new_length > 0
    ), "Calculated start and stop points result in a length of zero for this rank."

    return rank_start, rank_stop


def get_reader(
    text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
):
    """
    Create an EmbeddingReader object from the specified URLs

    get_reader() will always expect a url to image embeddings.

    If text-conditioned, it will also expect a meta_url for the captions.
    Otherwise, it will need txt_url for the matching text embeddings.

    Returns an image_reader object if text-conditioned.
    Otherwise it returns both an image_reader and a text_reader
    """

    assert img_url is not None, "Must supply an image url"

    if text_conditioned:
        assert meta_url is not None, "Must supply meta url if text-conditioned"

        image_reader = EmbeddingReader(
            embeddings_folder=img_url,
            file_format="parquet_npy",
            # will assume the caption column exists and is the only one requested
            meta_columns=["caption"],
            metadata_folder=meta_url,
        )

        return image_reader

    # otherwise we will require text embeddings as well and return two readers
    assert (
        txt_url is not None
    ), "Must supply text embedding url if not text-conditioning"

    image_reader = EmbeddingReader(img_url, file_format="npy")
    text_reader = EmbeddingReader(txt_url, file_format="npy")

    return image_reader, text_reader


def make_splits(
    text_conditioned: bool,
    batch_size: int,
    num_data_points: int,
    train_split: float,
    eval_split: float,
    image_reader: EmbeddingReader,
    text_reader: EmbeddingReader = None,
    start=0,
    rank=0,
    world_size=1,
):
    """
    Split an embedding reader object as needed.

    NOTE: make_splits() will infer the test set size from your train and eval.

    Input:
        - text_conditioned: whether to prepare text-conditioned training data
        - batch_size: the batch size for a single gpu
        - num_data_points: the total number of data points you wish to train on
        - train_split: the percentage of data you wish to train on
        - eval_split: the percentage of data you wish to validate on
        - image_reader: the image_reader you wish to split
        - text_reader: the text_reader you want to split (if !text_conditioned)
        - start: the starting point within your dataset
        - rank: the rank of your worker
        - world_size: the total world size of your distributed training run

    Returns:
        - PyTorch Dataloaders that yield tuples of (img, txt) data.
    """

    assert start < image_reader.count, "start position cannot exceed reader count."

    # verify that the num_data_points does not exceed the max points
    if num_data_points > (image_reader.count - start):
        print(
            "Specified count is larger than what's available...defaulting to reader's count."
        )
        num_data_points = image_reader.count

    # compute split points
    train_set_size = int(train_split * num_data_points)
    eval_set_size = int(eval_split * num_data_points)
    eval_start = train_set_size
    eval_stop = int(eval_start + eval_set_size)

    assert (
        train_split + eval_split
    ) < 1.0, "Specified train and eval split is too large to infer a test split."

    # distribute to rank
    rank_train_start, rank_train_stop = distribute_to_rank(
        start, train_set_size, rank, world_size
    )
    rank_eval_start, rank_eval_stop = distribute_to_rank(
        train_set_size, eval_stop, rank, world_size
    )
    rank_test_start, rank_test_stop = distribute_to_rank(
        eval_stop, num_data_points, rank, world_size
    )

    # wrap up splits into a dict
    train_split_args = dict(
        start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
    )
    eval_split_args = dict(
        start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
    )
    test_split_args = dict(
        start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
    )

    if text_conditioned:
        # add the text-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    else:
        # add the non-conditioned args to a unified dict
        reader_args = dict(
            text_conditioned=text_conditioned,
            image_reader=image_reader,
            text_reader=text_reader,
        )

        train_split_args = dict(**reader_args, **train_split_args)
        eval_split_args = dict(**reader_args, **eval_split_args)
        test_split_args = dict(**reader_args, **test_split_args)

        train = PriorEmbeddingDataset(**train_split_args)
        val = PriorEmbeddingDataset(**eval_split_args)
        test = PriorEmbeddingDataset(**test_split_args)

    # true batch size is specified in the PriorEmbeddingDataset
    train_loader = DataLoader(train, batch_size=None)
    eval_loader = DataLoader(val, batch_size=None)
    test_loader = DataLoader(test, batch_size=None)

    return train_loader, eval_loader, test_loader
```
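As a quick sanity check of `distribute_to_rank()` above, the arithmetic it performs for the README example (10,000 points across a world size of 3) works out as follows:

```python
from math import ceil

start, stop, world_size = 0, 10_000, 3
per_rank = ceil((stop - start) / world_size)  # 3334 samples per rank

for rank in range(world_size):
    rank_start = start + rank * per_rank
    rank_stop = min(rank_start + per_rank, stop)  # final rank is clamped to stop
    print(rank, rank_start, rank_stop)
# 0 0 3334
# 1 3334 6668
# 2 6668 10000
```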
dalle2_pytorch/trainer.py:

```diff
@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer
+from dalle2_pytorch.version import __version__
+from packaging import version
 
 import numpy as np
 
```
```diff
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
     return arr
 
 def get_pkg_version():
-    from pkg_resources import get_distribution
-    return get_distribution('dalle2_pytorch').version
+    return __version__
 
 # decorators
 
```
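Returning the in-repo `__version__` rather than querying `pkg_resources` means `get_pkg_version()` reports the right version even when running from a source checkout, where installed-distribution metadata can be stale or missing.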
```diff
@@ -299,7 +300,7 @@ class DiffusionPriorTrainer(nn.Module):
             scaler = self.scaler.state_dict(),
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
-            version = get_pkg_version(),
+            version = __version__,
             step = self.step.item(),
             **kwargs
         )
```
```diff
@@ -315,8 +316,8 @@ class DiffusionPriorTrainer(nn.Module):
 
         loaded_obj = torch.load(str(path))
 
-        if get_pkg_version() != loaded_obj['version']:
-            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
+        if version.parse(__version__) != loaded_obj['version']:
+            print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
 
         self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
```
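One subtlety of the `packaging` API used here (a sketch for illustration, not part of the diff): `version.parse()` returns `Version` objects that compare semantically, and equality against a raw string never holds, so comparisons generally want both sides parsed:

```python
from packaging import version

# Version objects compare semantically, not lexicographically
assert version.parse("0.10.0") > version.parse("0.6.2")

# equality between a Version and a plain string never holds,
# so a check like the one above only matches when both sides are parsed
assert version.parse("0.6.2") != "0.6.2"
assert version.parse("0.6.2") == version.parse("0.6.2")
```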
```diff
@@ -463,7 +464,7 @@ class DecoderTrainer(nn.Module):
 
         save_obj = dict(
             model = self.decoder.state_dict(),
-            version = get_pkg_version(),
+            version = __version__,
             step = self.step.item(),
             **kwargs
         )
```
```diff
@@ -486,7 +487,7 @@ class DecoderTrainer(nn.Module):
 
         loaded_obj = torch.load(str(path))
 
-        if get_pkg_version() != loaded_obj['version']:
+        if version.parse(__version__) != loaded_obj['version']:
             print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
 
         self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
```
dalle2_pytorch/version.py (new file, 1 line):

```python
__version__ = '0.6.2'
```
setup.py:

```diff
@@ -1,4 +1,5 @@
 from setuptools import setup, find_packages
+exec(open('dalle2_pytorch/version.py').read())
 
 setup(
   name = 'dalle2-pytorch',
```
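The `exec` line is the usual single-sourcing pattern for package versions: it defines `__version__` inside `setup.py`'s namespace without importing the package itself, whose dependencies may not be installed yet at build time.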
```diff
@@ -10,7 +11,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.5.6',
+  version = __version__,
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
train_diffusion_prior.py:

```diff
@@ -7,15 +7,13 @@ import torch
 import clip
 from torch import nn
 
-from dalle2_pytorch.dataloaders import make_splits
+from dalle2_pytorch.dataloaders import make_splits, get_reader
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
 from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
 
 from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
 from dalle2_pytorch.utils import Timer, print_ribbon
 
-from embedding_reader import EmbeddingReader
-
 from tqdm import tqdm
 
 # constants
```
```diff
@@ -31,7 +29,7 @@ def exists(val):
 
 # functions
 
-def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
+def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
     model.eval()
 
     with torch.no_grad():
```
```diff
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
         total_samples = 0.
 
         for image_embeddings, text_data in tqdm(dataloader):
+            image_embeddings = image_embeddings.to(device)
+            text_data = text_data.to(device)
 
             batches = image_embeddings.shape[0]
 
```
```diff
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
 
         tracker.log({f'{phase} {loss_type}': avg_loss})
 
-def report_cosine_sims(diffusion_prior, dataloader, text_conditioned):
+def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
     diffusion_prior.eval()
 
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
 
     for test_image_embeddings, text_data in tqdm(dataloader):
+        test_image_embeddings = test_image_embeddings.to(device)
+        text_data = text_data.to(device)
 
         # we are text conditioned, we produce an embedding from the tokenized text
         if text_conditioned:
```
```diff
@@ -296,15 +298,31 @@ def train(
 
     # Utilize wrapper to abstract away loader logic
    print_ribbon("Downloading Embeddings")
-    loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points,
-                       train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
+    reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
 
     if dp_condition_on_text_encodings:
-        loader_args = dict(**loader_args, meta_url=meta_url)
+        reader_args = dict(**reader_args, meta_url=meta_url)
+        img_reader = get_reader(**reader_args)
+        train_loader, eval_loader, test_loader = make_splits(
+            text_conditioned=dp_condition_on_text_encodings,
+            batch_size=batch_size,
+            num_data_points=num_data_points,
+            train_split=train_percent,
+            eval_split=val_percent,
+            image_reader=img_reader
+        )
     else:
-        loader_args = dict(**loader_args, txt_url=text_embed_url)
-    train_loader, eval_loader, test_loader = make_splits(**loader_args)
+        reader_args = dict(**reader_args, txt_url=text_embed_url)
+        img_reader, txt_reader = get_reader(**reader_args)
+        train_loader, eval_loader, test_loader = make_splits(
+            text_conditioned=dp_condition_on_text_encodings,
+            batch_size=batch_size,
+            num_data_points=num_data_points,
+            train_split=train_percent,
+            eval_split=val_percent,
+            image_reader=img_reader,
+            text_reader=txt_reader
+        )
 
     ### Training code ###
 
```
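The net effect of this refactor is a cleaner separation of concerns: `get_reader()` is called once to configure the `EmbeddingReader`(s), while `make_splits()` only slices them into train/eval/test loaders, which is what enables the `rank`/`world_size` distribution documented in the dataloaders README above.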
```diff
@@ -315,9 +333,11 @@ def train(
     for _ in range(epochs):
 
         for image, text in tqdm(train_loader):
 
             diffusion_prior.train()
 
+            image = image.to(device)
+            text = text.to(device)
+
             input_args = dict(image_embed=image)
             if dp_condition_on_text_encodings:
                 input_args = dict(**input_args, text = text)
```
```diff
@@ -350,9 +370,9 @@ def train(
         # Use NUM_TEST_EMBEDDINGS samples from the test set each time
         # Get embeddings from the most recently saved model
         if(step % REPORT_METRICS_EVERY) == 0:
-            report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings)
+            report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
             ### Evaluate model(validation run) ###
-            eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation")
+            eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
 
         step += 1
         trainer.update()
```