mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b9b4b9e5e | ||
|
|
44e09d5a4d | ||
|
|
34806663e3 | ||
|
|
dc816b1b6e | ||
|
|
05192ffac4 | ||
|
|
9440411954 | ||
|
|
981d407792 | ||
|
|
7c5477b26d | ||
|
|
be3bb868bf | ||
|
|
451de34871 | ||
|
|
f22e8c8741 | ||
|
|
87432e93ad | ||
|
|
d167378401 | ||
|
|
2d67d5821e | ||
|
|
748c7fe7af | ||
|
|
80046334ad | ||
|
|
36fb46a95e | ||
|
|
07abfcf45b | ||
|
|
2e35a9967d | ||
|
|
406e75043f | ||
|
|
9646dfc0e6 | ||
|
|
62043acb2f |
51
README.md
51
README.md
@@ -371,6 +371,7 @@ loss.backward()
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
@@ -395,7 +396,7 @@ decoder = Decoder(
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -626,6 +627,18 @@ images = dalle2(
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
|
||||
|
||||
```bash
|
||||
$ pip install open-clip-torch
|
||||
```
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import OpenClipAdapter
|
||||
|
||||
clip = OpenClipAdapter()
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Inpainting
|
||||
@@ -860,25 +873,23 @@ unet1 = Unet(
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True,
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
timesteps = 1000
|
||||
).cuda()
|
||||
|
||||
decoder_trainer = DecoderTrainer(
|
||||
@@ -903,8 +914,8 @@ for unet_number in (1, 2):
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
mock_image_embed = torch.randn(32, 512).cuda()
|
||||
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
### Diffusion Prior Training
|
||||
@@ -1112,7 +1123,8 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [x] speed up inference, read up on papers (ddim)
|
||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
|
||||
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
|
||||
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
|
||||
## Citations
|
||||
@@ -1241,4 +1253,25 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2022analog,
|
||||
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
|
||||
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
|
||||
year = {2022},
|
||||
eprint = {2208.04202},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Qiao2019WeightS,
|
||||
title = {Weight Standardization},
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||
journal = {ArXiv},
|
||||
year = {2019},
|
||||
volume = {abs/1903.10520}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
accelerator,
|
||||
accelerator = None,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
@@ -186,8 +186,12 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert isinstance(accelerator, Accelerator)
|
||||
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
|
||||
|
||||
if not exists(accelerator):
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
@@ -300,7 +304,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
if isinstance(path_or_state, str):
|
||||
path = Path(path)
|
||||
path = Path(path_or_state)
|
||||
assert path.exists()
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
@@ -429,6 +433,7 @@ class DecoderTrainer(nn.Module):
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -450,7 +455,7 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
@@ -458,7 +463,7 @@ class DecoderTrainer(nn.Module):
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
@@ -474,7 +479,11 @@ class DecoderTrainer(nn.Module):
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
@@ -554,9 +563,15 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -577,10 +592,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
if optimizer is not None:
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.0.3'
|
||||
__version__ = '1.8.0'
|
||||
|
||||
Reference in New Issue
Block a user