mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b2cf7b0bc | ||
|
|
984d62a373 | ||
|
|
683dd98b96 | ||
|
|
067ac323da | ||
|
|
91c8d1ca13 | ||
|
|
08238a7200 | ||
|
|
7166ad6711 | ||
|
|
fbba0f9aaf | ||
|
|
9f37705d87 | ||
|
|
c3df46e374 | ||
|
|
41fabf2922 | ||
|
|
5975e8222b | ||
|
|
c18c080128 | ||
|
|
b39653cf96 | ||
|
|
39f8b6cf16 | ||
|
|
d0c11b30b0 | ||
|
|
86e2d5ba84 | ||
|
|
0d82dff9c5 | ||
|
|
8bbc956ff1 | ||
|
|
22019fddeb | ||
|
|
6fb7e91343 | ||
|
|
ba58ae0bf2 | ||
|
|
1cc5d0afa7 | ||
|
|
59fa101c4d | ||
|
|
916ece164c | ||
|
|
cbaadb6931 | ||
|
|
083508ff8e | ||
|
|
7762edd0ff | ||
|
|
de5e628773 | ||
|
|
1b4046b039 | ||
|
|
27f19ba7fa | ||
|
|
8f38339c2b | ||
|
|
6b9b4b9e5e | ||
|
|
44e09d5a4d | ||
|
|
34806663e3 | ||
|
|
dc816b1b6e | ||
|
|
05192ffac4 | ||
|
|
9440411954 | ||
|
|
981d407792 | ||
|
|
7c5477b26d | ||
|
|
be3bb868bf | ||
|
|
451de34871 | ||
|
|
f22e8c8741 | ||
|
|
87432e93ad | ||
|
|
d167378401 | ||
|
|
2d67d5821e | ||
|
|
748c7fe7af | ||
|
|
80046334ad | ||
|
|
36fb46a95e | ||
|
|
07abfcf45b | ||
|
|
2e35a9967d | ||
|
|
406e75043f | ||
|
|
9646dfc0e6 | ||
|
|
62043acb2f |
87
README.md
87
README.md
@@ -49,6 +49,7 @@ This library would not have gotten to this working state without the help of
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
|
||||
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
|
||||
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
@@ -371,6 +372,7 @@ loss.backward()
|
||||
unet1 = Unet(
|
||||
dim = 128,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
@@ -395,7 +397,7 @@ decoder = Decoder(
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
|
||||
loss.backward()
|
||||
|
||||
# do above for many steps
|
||||
@@ -626,6 +628,20 @@ images = dalle2(
|
||||
# save your image (in this example, of size 256x256)
|
||||
```
|
||||
|
||||
Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
|
||||
|
||||
```bash
|
||||
$ pip install open-clip-torch
|
||||
```
|
||||
|
||||
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import OpenClipAdapter
|
||||
|
||||
clip = OpenClipAdapter('ViT-H/14')
|
||||
```
|
||||
|
||||
Now you'll just have to worry about training the Prior and the Decoder!
|
||||
|
||||
## Inpainting
|
||||
@@ -860,25 +876,23 @@ unet1 = Unet(
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults=(1, 2, 4, 8)
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
cond_on_text_encodings = True,
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
image_embed_dim = 512,
|
||||
text_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
timesteps = 1000
|
||||
).cuda()
|
||||
|
||||
decoder_trainer = DecoderTrainer(
|
||||
@@ -903,8 +917,8 @@ for unet_number in (1, 2):
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
mock_image_embed = torch.randn(32, 512).cuda()
|
||||
images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
### Diffusion Prior Training
|
||||
@@ -1054,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
|
||||
)
|
||||
for img, emb in dataloader:
|
||||
print(img.shape) # torch.Size([32, 3, 256, 256])
|
||||
print(emb.shape) # torch.Size([32, 512])
|
||||
print(emb["img"].shape) # torch.Size([32, 512])
|
||||
# Train decoder only as shown above
|
||||
|
||||
# Or create a dataset without a loader so you can configure it manually
|
||||
@@ -1112,7 +1126,9 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
- [x] allow for unet to be able to condition non-cross attention style as well
|
||||
- [x] speed up inference, read up on papers (ddim)
|
||||
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
|
||||
- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
|
||||
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
|
||||
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
|
||||
- [ ] add simple outpainting, text-guided 2x size the image for starters
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
|
||||
## Citations
|
||||
@@ -1241,4 +1257,55 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{chen2022analog,
|
||||
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
|
||||
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
|
||||
year = {2022},
|
||||
eprint = {2208.04202},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Qiao2019WeightS,
|
||||
title = {Weight Standardization},
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||
journal = {ArXiv},
|
||||
year = {2019},
|
||||
volume = {abs/1903.10520}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{rogozhnikov2022einops,
|
||||
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
|
||||
author = {Alex Rogozhnikov},
|
||||
booktitle = {International Conference on Learning Representations},
|
||||
year = {2022},
|
||||
url = {https://openreview.net/forum?id=oapKSVM2bcj}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Sunkara2022NoMS,
|
||||
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
|
||||
author = {Raja Sunkara and Tie Luo},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2208.03641}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Salimans2022ProgressiveDF,
|
||||
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
|
||||
author = {Tim Salimans and Jonathan Ho},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2202.00512}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dalle2_pytorch.version import __version__
|
||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,11 +4,13 @@ from pydantic import BaseModel, validator, root_validator
|
||||
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
|
||||
|
||||
from x_clip import CLIP as XCLIP
|
||||
from open_clip import list_pretrained
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import (
|
||||
CoCaAdapter,
|
||||
OpenAIClipAdapter,
|
||||
OpenClipAdapter,
|
||||
Unet,
|
||||
Decoder,
|
||||
DiffusionPrior,
|
||||
@@ -117,6 +119,10 @@ class AdapterConfig(BaseModel):
|
||||
def create(self):
|
||||
if self.make == "openai":
|
||||
return OpenAIClipAdapter(self.model)
|
||||
elif self.make == "open_clip":
|
||||
pretrained = dict(list_pretrained())
|
||||
checkpoint = pretrained[self.model]
|
||||
return OpenClipAdapter(name=self.model, pretrained=checkpoint)
|
||||
elif self.make == "x-clip":
|
||||
return XClipAdapter(XCLIP(**self.base_model_kwargs))
|
||||
elif self.make == "coca":
|
||||
@@ -241,7 +247,7 @@ class DecoderConfig(BaseModel):
|
||||
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
|
||||
channels: int = 3
|
||||
timesteps: int = 1000
|
||||
sample_timesteps: Optional[SingularOrIterable[int]] = None
|
||||
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
|
||||
loss_type: str = 'l2'
|
||||
beta_schedule: ListOrTuple[str] = None # None means all cosine
|
||||
learned_variance: SingularOrIterable[bool] = True
|
||||
@@ -307,6 +313,7 @@ class DecoderTrainConfig(BaseModel):
|
||||
wd: SingularOrIterable[float] = 0.01
|
||||
warmup_steps: Optional[SingularOrIterable[int]] = None
|
||||
find_unused_parameters: bool = True
|
||||
static_graph: bool = True
|
||||
max_grad_norm: SingularOrIterable[float] = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
|
||||
@@ -9,7 +9,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -174,20 +174,25 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
accelerator,
|
||||
accelerator = None,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
group_wd_params = True,
|
||||
warmup_steps = 1,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert isinstance(accelerator, Accelerator)
|
||||
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
|
||||
|
||||
if not exists(accelerator):
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
@@ -229,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
if exists(cosine_decay_max_steps):
|
||||
self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
|
||||
else:
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||
|
||||
@@ -267,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||
save_obj = dict(
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
scheduler = self.scheduler.state_dict(),
|
||||
warmup_scheduler = self.warmup_scheduler,
|
||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||
version = version.parse(__version__),
|
||||
@@ -300,7 +309,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
if isinstance(path_or_state, str):
|
||||
path = Path(path)
|
||||
path = Path(path_or_state)
|
||||
assert path.exists()
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
@@ -313,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
self.scheduler.load_state_dict(loaded_obj['scheduler'])
|
||||
|
||||
# set warmupstep
|
||||
if exists(self.warmup_scheduler):
|
||||
@@ -346,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
with self.warmup_scheduler.dampening():
|
||||
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
|
||||
with sched_context():
|
||||
self.scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -450,7 +463,7 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
@@ -458,7 +471,7 @@ class DecoderTrainer(nn.Module):
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
@@ -474,7 +487,11 @@ class DecoderTrainer(nn.Module):
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
@@ -554,9 +571,15 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -577,10 +600,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
if optimizer is not None:
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.0.3'
|
||||
__version__ = '1.12.1'
|
||||
|
||||
3
setup.py
3
setup.py
@@ -26,7 +26,8 @@ setup(
|
||||
install_requires=[
|
||||
'accelerate',
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'open-clip-torch>=2.0.0,<3.0.0',
|
||||
'clip-anytorch>=2.4.0',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'ema-pytorch>=0.0.7',
|
||||
'einops>=0.4',
|
||||
|
||||
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
|
||||
break
|
||||
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
|
||||
|
||||
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
|
||||
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
|
||||
"""
|
||||
Takes example data and generates images from the embeddings
|
||||
Returns three lists: real images, generated images, and captions
|
||||
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
if img_embeddings[0] is None:
|
||||
# Generate image embeddings from clip
|
||||
imgs_tensor = torch.stack(real_images)
|
||||
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
|
||||
assert clip is not None, "clip is None, but img_embeddings is None"
|
||||
imgs_tensor.to(device=device)
|
||||
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
|
||||
sample_params["image_embed"] = img_embeddings
|
||||
else:
|
||||
# Then we are using precomputed image embeddings
|
||||
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
if condition_on_text_encodings:
|
||||
if text_embeddings[0] is None:
|
||||
# Generate text embeddings from text
|
||||
tokenized_texts = tokenize(txts, truncate=True)
|
||||
sample_params["text"] = tokenized_texts
|
||||
assert clip is not None, "clip is None, but text_embeddings is None"
|
||||
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
sample_params["text_encodings"] = text_encodings
|
||||
else:
|
||||
# Then we are using precomputed text embeddings
|
||||
text_embeddings = torch.stack(text_embeddings)
|
||||
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
sample_params["image"] = torch.stack(real_images)
|
||||
if device is not None:
|
||||
sample_params["_device"] = device
|
||||
samples = trainer.sample(**sample_params)
|
||||
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
|
||||
generated_images = list(samples)
|
||||
captions = [text_prepend + txt for txt in txts]
|
||||
if match_image_size:
|
||||
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
|
||||
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
|
||||
return real_images, generated_images, captions
|
||||
|
||||
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
|
||||
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
|
||||
"""
|
||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||
"""
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
|
||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||
return grid_images, captions
|
||||
|
||||
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
|
||||
"""
|
||||
Computes evaluation metrics for the decoder
|
||||
"""
|
||||
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
|
||||
if len(examples) == 0:
|
||||
print("No data to evaluate. Check that your dataloader has shards.")
|
||||
return metrics
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
|
||||
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
|
||||
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
|
||||
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
@@ -225,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
|
||||
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
|
||||
lpips.to(device=device)
|
||||
lpips.update(renorm_real_images, renorm_generated_images)
|
||||
@@ -265,6 +269,7 @@ def train(
|
||||
accelerator: Accelerator,
|
||||
tracker: Tracker,
|
||||
inference_device,
|
||||
clip=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
@@ -371,15 +376,19 @@ def train(
|
||||
forward_params['image_embed'] = img_emb
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
|
||||
trainer.update(unet_number=unet)
|
||||
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
|
||||
@@ -419,7 +428,7 @@ def train(
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
@@ -462,15 +471,19 @@ def train(
|
||||
forward_params['image_embed'] = img_emb.float()
|
||||
else:
|
||||
# Forward pass automatically generates embedding
|
||||
pass
|
||||
assert clip is not None
|
||||
img_embed, img_encoding = clip.embed_image(img)
|
||||
forward_params['image_embed'] = img_embed
|
||||
if condition_on_text_encodings:
|
||||
if has_text_embedding:
|
||||
forward_params['text_encodings'] = text_emb.float()
|
||||
else:
|
||||
# Then we need to pass the text instead
|
||||
tokenized_texts = tokenize(txt, truncate=True)
|
||||
assert clip is not None
|
||||
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
|
||||
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
|
||||
forward_params['text'] = tokenized_texts
|
||||
text_embed, text_encodings = clip.embed_text(tokenized_texts)
|
||||
forward_params['text_encodings'] = text_encodings
|
||||
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
|
||||
average_val_loss_tensor[0, unet-1] += loss
|
||||
|
||||
@@ -498,7 +511,7 @@ def train(
|
||||
if next_task == 'eval':
|
||||
if exists(evaluate_config):
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step())
|
||||
next_task = 'sample'
|
||||
@@ -509,8 +522,8 @@ def train(
|
||||
# Generate examples and save the model if we are the master
|
||||
# Generate sample images
|
||||
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
|
||||
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
|
||||
@@ -542,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
|
||||
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
|
||||
|
||||
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
|
||||
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
|
||||
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
|
||||
|
||||
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
|
||||
# This is an invalid configuration until we figure out how to handle this
|
||||
raise ValueError("DeepSpeed does not support multi-node distributed training")
|
||||
|
||||
# Set up data
|
||||
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
|
||||
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
seed = config.seed,
|
||||
)
|
||||
|
||||
# If clip is in the model, we need to remove it for compatibility with deepspeed
|
||||
clip = None
|
||||
if config.decoder.clip is not None:
|
||||
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
|
||||
config.decoder.clip = None
|
||||
# Create the decoder model and print basic info
|
||||
decoder = config.decoder.create()
|
||||
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
|
||||
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
|
||||
|
||||
has_clip_model = config.decoder.clip is not None
|
||||
has_clip_model = clip is not None
|
||||
data_source_string = ""
|
||||
|
||||
if has_img_embeddings:
|
||||
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
|
||||
|
||||
train(dataloaders, decoder, accelerator,
|
||||
clip=clip,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
evaluate_config=config.evaluate,
|
||||
|
||||
Reference in New Issue
Block a user