mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 19:44:26 +01:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3676ef4d49 | ||
|
|
28e944f328 | ||
|
|
14e63a3f67 | ||
|
|
09e9eaa5a6 | ||
|
|
e6d752cf4a | ||
|
|
ad20a14a4d | ||
|
|
0be1e0d64c | ||
|
|
98df1ba51e | ||
|
|
878b555ef7 | ||
|
|
63029f7388 | ||
|
|
c76a964fd6 | ||
|
|
79fabc4341 | ||
|
|
f7ef4bde38 | ||
|
|
93ba019069 | ||
|
|
8518684ae9 |
201
README.md
201
README.md
@@ -587,47 +587,6 @@ images = dalle2(
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Dataloaders
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
### Decoder: Image Embedding Dataset
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -827,6 +786,149 @@ mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
### Diffusion Prior Training
|
||||
|
||||
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# prior networks (with transformer)
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
ema_beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss.backward()
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
|
||||
|
||||
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||
```
|
||||
|
||||
### Decoder Dataloaders
|
||||
|
||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||
|
||||
#### Decoder: Image Embedding Dataset
|
||||
|
||||
When training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.
|
||||
|
||||
Generating a dataset of this type:
|
||||
1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
|
||||
2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
|
||||
3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
|
||||
|
||||
# Create a dataloader directly.
|
||||
dataloader = create_image_embedding_dataloader(
|
||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||
num_workers=4,
|
||||
batch_size=32,
|
||||
shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
|
||||
shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
|
||||
shuffle_shards=True, # Shuffle the order the shards are read in
|
||||
resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
dataset = ImageEmbeddingDataset(
|
||||
urls="/path/or/url/to/webdataset/{0000..9999}.tar",
|
||||
embedding_folder_url="path/or/url/to/embeddings/folder",
|
||||
shard_width=4,
|
||||
shuffle_shards=True,
|
||||
resample=False
|
||||
)
|
||||
```
|
||||
|
||||
## Scripts
|
||||
|
||||
### Using the `train_diffusion_prior.py` script
|
||||
|
||||
This script allows training the DiffusionPrior on pre-computed text and image embeddings. The working example below elucidates this process.
|
||||
Please note that the script internally passes text_embed and image_embed to the DiffusionPrior, unlike the example below.
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
$ pyhon train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
|
||||
--image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
|
||||
--text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
|
||||
--image-embed-dim, default=768 - 768 corresponds to the ViT iL/14 embedding size,change it to what your chosen ViT generates
|
||||
|
||||
--learning-rate, default=1.1e-4
|
||||
|
||||
--weight-decay, default=6.02e-2
|
||||
|
||||
--max-grad-norm, default=0.5
|
||||
|
||||
--batch-size, default=10 ** 4
|
||||
|
||||
--num-epochs, default=5
|
||||
|
||||
--clip, default=None # Signals the prior to use pre-computed embeddings
|
||||
|
||||
### Sample wandb run log
|
||||
|
||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
@@ -864,8 +966,8 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
@@ -877,7 +979,8 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -945,4 +1048,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2022CoCaCC,
|
||||
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
|
||||
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2205.01917}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.train import DecoderTrainer
|
||||
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from x_clip import CLIP
|
||||
|
||||
@@ -23,9 +23,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
from resize_right import resize
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
# use x-clip
|
||||
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -113,9 +118,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
|
||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
def __init__(self, clip):
|
||||
def __init__(self, clip, **kwargs):
|
||||
super().__init__()
|
||||
self.clip = clip
|
||||
self.overrides = kwargs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
@@ -173,6 +179,39 @@ class XClipAdapter(BaseClipAdapter):
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||
|
||||
class CoCaAdapter(BaseClipAdapter):
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return self.clip.dim
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
assert 'image_size' in self.overrides
|
||||
return self.overrides['image_size']
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
assert 'image_channels' in self.overrides
|
||||
return self.overrides['image_channels']
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
assert 'max_text_len' in self.overrides
|
||||
return self.overrides['max_text_len']
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image_embed, image_encodings = self.clip.embed_image(image)
|
||||
return EmbeddedImage(image_embed, image_encodings)
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +264,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -233,7 +272,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(image_embed.float(), None)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -531,7 +570,8 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -547,6 +587,8 @@ class Attention(nn.Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
@@ -559,6 +601,12 @@ class Attention(nn.Module):
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||
q = q * self.scale
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
if exists(self.rotary_emb):
|
||||
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
@@ -566,7 +614,7 @@ class Attention(nn.Module):
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
# calculate query / key similarities
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
@@ -616,15 +664,18 @@ class CausalTransformer(nn.Module):
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
final_proj = True,
|
||||
normformer = False
|
||||
normformer = False,
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -652,14 +703,12 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
self.l2norm_output = l2norm_output
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -738,8 +787,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
pred_image_embed = tokens[..., -1, :]
|
||||
|
||||
output_fn = l2norm if self.l2norm_output else identity
|
||||
return output_fn(pred_image_embed)
|
||||
return pred_image_embed
|
||||
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
@@ -757,7 +805,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
training_clamp_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -767,7 +817,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
@@ -787,10 +839,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
self.predict_x_start = predict_x_start
|
||||
|
||||
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
self.image_embed_scale = default(image_embed_scale, image_embed_dim ** 0.5)
|
||||
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
@@ -843,11 +896,26 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
return loss
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample_batch_size(self, batch_size, text_cond):
|
||||
device = self.betas.device
|
||||
shape = (batch_size, self.image_embed_dim)
|
||||
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
|
||||
@torch.inference_mode()
|
||||
@eval_decorator
|
||||
def sample(self, text, num_samples_per_batch = 2):
|
||||
@@ -1460,7 +1528,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self,
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
clip = None,
|
||||
image_size = None,
|
||||
channels = 3,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
image_cond_drop_prob = 0.1,
|
||||
@@ -1476,7 +1546,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1484,15 +1555,24 @@ class Decoder(BaseGaussianDiffusion):
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
freeze_model_and_make_eval_(clip)
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
freeze_model_and_make_eval_(clip)
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
else:
|
||||
self.clip_image_size = image_size
|
||||
self.channels = channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
@@ -1525,7 +1605,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (clip.image_size,))
|
||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
@@ -1730,10 +1810,12 @@ class Decoder(BaseGaussianDiffusion):
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
if not exists(image_embed):
|
||||
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
|
||||
image_embed, _ = self.clip.embed_image(image)
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text) and not exists(text_encodings):
|
||||
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
|
||||
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
# helper functions
|
||||
@@ -89,7 +89,83 @@ class EMA(nn.Module):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
# trainers
|
||||
# diffusion prior trainer
|
||||
|
||||
class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
|
||||
self.amp = amp
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
|
||||
self.optimizer = get_optimizer(
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
@torch.inference_mode()
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*args,
|
||||
divisor = 1,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*args, **kwargs)
|
||||
return self.scaler.scale(loss / divisor)
|
||||
|
||||
# decoder trainer
|
||||
|
||||
class DecoderTrainer(nn.Module):
|
||||
def __init__(
|
||||
|
||||
@@ -3,14 +3,15 @@ import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
|
||||
import torchvision.transforms as T
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
self.amp = amp
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
self.discr_scaler = GradScaler(enabled = amp)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
|
||||
self.scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
|
||||
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_scaler.step(self.discr_optim)
|
||||
self.discr_scaler.update()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
4
setup.py
4
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.104',
|
||||
version = '0.1.5',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -24,12 +24,14 @@ setup(
|
||||
install_requires=[
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
|
||||
@@ -85,7 +85,6 @@ def train(image_embed_dim,
|
||||
clip,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_l2norm_output,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dpn_depth,
|
||||
@@ -105,8 +104,7 @@ def train(image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
normformer = dp_normformer,
|
||||
l2norm_output = dp_l2norm_output).to(device)
|
||||
normformer = dp_normformer).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(
|
||||
@@ -273,7 +271,6 @@ def main():
|
||||
args.clip,
|
||||
args.dp_condition_on_text_encodings,
|
||||
args.dp_timesteps,
|
||||
args.dp_l2norm_output,
|
||||
args.dp_normformer,
|
||||
args.dp_cond_drop_prob,
|
||||
args.dpn_depth,
|
||||
|
||||
Reference in New Issue
Block a user