Compare commits

46 Commits
1.8.0...main

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| lucidrains | 680dfc4d93 | yet more pydantic v2 stuff | 2023-10-19 07:40:57 -07:00 |
| lucidrains | b6fecae91a | fix another pydantic 2 migration error | 2023-10-18 21:07:47 -07:00 |
| lucidrains | dab2f74650 | fix self_attn type on unetconfig | 2023-10-18 21:02:50 -07:00 |
| lucidrains | 1e173f4c66 | more fixes to config | 2023-10-18 20:27:32 -07:00 |
| lucidrains | 410a6144e1 | new einops is torch compile friendly | 2023-10-18 15:45:09 -07:00 |
| lucidrains | c6c3882dc1 | fix all optional types in train config | 2023-10-07 11:34:34 -07:00 |
| Phil Wang | 512b52bd78 | 1.15.2 | 2023-10-04 09:38:46 -07:00 |
| Neil Kim Nielsen | 147c156c8a | Make TrackerLoadConfig optional (#306) | 2023-10-04 09:38:30 -07:00 |
| Phil Wang | 40843bcc21 | pydantic 2 | 2023-07-15 09:32:44 -07:00 |
| Phil Wang | 00e07b7d61 | force einops 0.6.1 or greater and call allow_ops_in_compiled_graph | 2023-04-20 14:08:52 -07:00 |
| Phil Wang | 0069857cf8 | remove einops exts for better pytorch 2.0 compile compatibility | 2023-04-20 07:05:29 -07:00 |
| Phil Wang | 580274be79 | use .to(device) to avoid copy, within one_unet_in_gpu context | 2023-03-07 12:41:55 -08:00 |
| Phil Wang | 848e8a480a | always rederive the predicted noise from the clipped x0 for ddim + predict noise objective | 2023-03-05 10:45:44 -08:00 |
| Phil Wang | cc58f75474 | bump to newer package of clip-anytorch that allows for text encodings < maximum context length | 2023-03-04 09:37:25 -08:00 |
| Phil Wang | 3b2cf7b0bc | fix for self conditioning in diffusion prior network https://github.com/lucidrains/DALLE2-pytorch/issues/273 | 2023-02-11 17:18:40 -08:00 |
| Phil Wang | 984d62a373 | default ddim sampling eta to 0 | 2022-12-23 13:23:09 -08:00 |
| Phil Wang | 683dd98b96 | extra insurance in case eos id is not there | 2022-12-15 10:54:21 -08:00 |
| Phil Wang | 067ac323da | address https://github.com/lucidrains/DALLE2-pytorch/issues/266 | 2022-11-23 08:41:25 -08:00 |
| zion | 91c8d1ca13 | bug fix cosine annealing optimizer in prior trainer (#262) | 2022-11-11 12:15:13 -08:00 |
| zion | 08238a7200 | depend on open-clip-torch (#261): fix the previous commit which assumes open_clip is installed | 2022-11-07 16:19:08 -08:00 |
| zion | 7166ad6711 | add open clip to train_config (#260): add the ability to use open_clip in the train configs (useful for the new SOTA h/14 model) | 2022-11-07 15:44:36 -08:00 |
| Phil Wang | fbba0f9aaf | bring in prediction of v objective, combining the findings from progressive distillation paper and imagen-video to the eventual extension of dalle2 to make-a-video | 2022-10-28 18:21:07 -07:00 |
| Romain Beaumont | 9f37705d87 | Add static graph param (#226): Add static graph param; use static graph param | 2022-10-25 19:31:29 +02:00 |
| Phil Wang | c3df46e374 | fix openclipadapter to be able to use latest open sourced sota model | 2022-10-23 15:12:09 -07:00 |
| Phil Wang | 41fabf2922 | fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT | 2022-10-19 09:26:06 -07:00 |
| Heng Jia | 5975e8222b | Fix assert message (#253) | 2022-10-18 08:50:59 -07:00 |
| Phil Wang | c18c080128 | fix for use with larger openai clip models by extracting dimension of last layernorm in clip | 2022-09-29 09:09:47 -07:00 |
| Phil Wang | b39653cf96 | fix readme dataloader example | 2022-09-20 08:39:52 -07:00 |
| Phil Wang | 39f8b6cf16 | show example of using SOTA open sourced open clip | 2022-09-19 10:45:20 -07:00 |
| Phil Wang | d0c11b30b0 | handle open clip adapter image size being a tuple | 2022-09-19 10:27:14 -07:00 |
| zion | 86e2d5ba84 | Minor Decoder Train Script Fixes (#242): ensure tokenized text is on proper device; fix lpips mage distribution | 2022-09-15 17:21:48 -07:00 |
| Phil Wang | 0d82dff9c5 | in ddim, noise should be predicted after x0 is maybe clipped, thanks to @lukovnikov for pointing this out in another repository | 2022-09-01 09:40:47 -07:00 |
| Phil Wang | 8bbc956ff1 | fix bug with misnamed variable in diffusion prior network | 2022-08-31 17:19:05 -07:00 |
| Phil Wang | 22019fddeb | todo | 2022-08-31 13:36:05 -07:00 |
| Phil Wang | 6fb7e91343 | fix ddim to use alpha_cumprod | 2022-08-31 07:40:46 -07:00 |
| Phil Wang | ba58ae0bf2 | add two asserts to diffusion prior to ensure matching image embedding dimensions for clip, diffusion prior network, and what was set on diffusion prior | 2022-08-28 10:11:37 -07:00 |
| Phil Wang | 1cc5d0afa7 | upgrade to best downsample | 2022-08-25 10:37:02 -07:00 |
| Phil Wang | 59fa101c4d | fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue | 2022-08-23 08:29:01 -07:00 |
| Aidan Dempster | 916ece164c | Merge pull request #234 from Veldrovive/deepspeed_fp16: Fixed issues with clip and deepspeed fp16 | 2022-08-20 19:01:43 -04:00 |
| Aidan | cbaadb6931 | Fixed issues with clip and deepspeed fp16: Also more more general compatibility fixes | 2022-08-20 17:58:32 +00:00 |
| Phil Wang | 083508ff8e | cast attention matrix back to original dtype pre-softmax in attention | 2022-08-20 10:56:01 -07:00 |
| Phil Wang | 7762edd0ff | make it work for @ethancohen123 | 2022-08-19 11:28:58 -07:00 |
| Phil Wang | de5e628773 | cite einops | 2022-08-17 08:58:41 -07:00 |
| Phil Wang | 1b4046b039 | gratitude | 2022-08-17 08:57:33 -07:00 |
| Phil Wang | 27f19ba7fa | make sure diffusion prior trainer can operate with no warmup | 2022-08-15 14:27:40 -07:00 |
| Phil Wang | 8f38339c2b | give diffusion prior trainer cosine annealing lr too | 2022-08-15 07:38:01 -07:00 |
10 changed files with 364 additions and 149 deletions

View File

@@ -49,6 +49,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship - <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library - <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏 ... and many others. Thank you! 🙏
@@ -633,10 +634,12 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch $ pip install open-clip-torch
``` ```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python ```python
from dalle2_pytorch import OpenClipAdapter from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter() clip = OpenClipAdapter('ViT-H/14')
``` ```
Now you'll just have to worry about training the Prior and the Decoder! Now you'll just have to worry about training the Prior and the Decoder!
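For context on this README change, a minimal sketch of wiring the newly exported OpenClipAdapter into a diffusion prior, loosely following the repository's README pattern; the network hyperparameters below are illustrative assumptions, not values taken from this diff, and running it requires open-clip-torch plus a ViT-H/14 checkpoint download.

```python
import torch
from dalle2_pytorch import OpenClipAdapter, DiffusionPriorNetwork, DiffusionPrior

clip = OpenClipAdapter('ViT-H/14')          # SOTA open_clip model, as in the README example above

prior_network = DiffusionPriorNetwork(
    dim = 1024,                             # assumed to match clip.dim_latent for ViT-H/14
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    cond_drop_prob = 0.2
)
```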
@@ -1065,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
) )
for img, emb in dataloader: for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256]) print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512]) print(emb["img"].shape) # torch.Size([32, 512])
# Train decoder only as shown above # Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually # Or create a dataset without a loader so you can configure it manually
@@ -1125,6 +1128,7 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments - [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364 - [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations ## Citations
@@ -1274,4 +1278,34 @@ For detailed information on training the diffusion prior, please refer to the [d
} }
``` ```
```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -9,7 +9,7 @@
"dim_mults": [1, 2, 4, 8], "dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16, "attn_dim_head": 16,
"attn_heads": 4, "attn_heads": 4,
"self_attn": [false, true, true, true] "self_attn": [false, true, true, true]
} }
], ],
"clip": { "clip": {

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.version import __version__ from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.vqgan_vae import VQGanVAE

View File

@@ -12,10 +12,8 @@ from torch.utils.checkpoint import checkpoint
from torch import nn, einsum from torch import nn, einsum
import torchvision.transforms as T import torchvision.transforms as T
from einops import rearrange, repeat, reduce from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d from kornia.filters import gaussian_blur2d
import kornia.augmentation as K import kornia.augmentation as K
@@ -100,6 +98,9 @@ def eval_decorator(fn):
return out return out
return inner return inner
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
def is_list_str(x): def is_list_str(x):
if not isinstance(x, (list, tuple)): if not isinstance(x, (list, tuple)):
return False return False
@@ -250,9 +251,15 @@ class XClipAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len] text = text[..., :self.max_text_len]
text_mask = text != 0 text_mask = text != 0
encoder_output = self.clip.text_transformer(text) encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
encoder_output_is_cls = encoder_output.ndim == 3
text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
text_embed = self.clip.to_text_latent(text_cls) text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
if exists(text_encodings):
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings) return EmbeddedText(l2norm(text_embed), text_encodings)
@torch.no_grad() @torch.no_grad()
@@ -308,7 +315,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!' self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final') text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook) self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1] self.clip_normalize = preprocess.transforms[-1]
self.cleared = False self.cleared = False
@@ -327,7 +337,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property @property
def dim_latent(self): def dim_latent(self):
return 512 return self.dim_latent_
@property @property
def image_size(self): def image_size(self):
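As a small aside, a sketch of what reading the latent dimension off `ln_final` buys (assuming OpenAI's `clip` package is installed; running it downloads the checkpoint on first use):

```python
import clip

# the text latent dimension is now read off CLIP's final LayerNorm instead of being
# hard-coded to 512, so larger OpenAI models work without modification
model, _ = clip.load('ViT-L/14', device = 'cpu')
print(model.ln_final.weight.shape[0])   # 768 for ViT-L/14 (a ViT-B/32 model would give 512)
```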
@@ -348,6 +358,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id) is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0 text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True) text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared assert not self.cleared
text_embed = self.clip.encode_text(text) text_embed = self.clip.encode_text(text)
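A toy run of the eos-aware text mask with the newly added padding guard; the token ids below are made up, with 49407 standing in for the eos id and 0 for padding:

```python
import torch
import torch.nn.functional as F

text = torch.tensor([[320, 1125, 49407, 0, 0]])                      # hypothetical tokenized prompt
is_eos_id = text == 49407
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)    # shift so the eos token itself is kept
text_mask = text_mask & (text != 0)                                  # the added insurance: never attend to padding
print(text_mask)                                                     # tensor([[ True,  True,  True, False, False]])
```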
@@ -377,6 +388,8 @@ class OpenClipAdapter(BaseClipAdapter):
self.eos_id = 49407 self.eos_id = 49407
text_attention_final = self.find_layer('ln_final') text_attention_final = self.find_layer('ln_final')
self._dim_latent = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook) self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1] self.clip_normalize = preprocess.transforms[-1]
self.cleared = False self.cleared = False
@@ -396,11 +409,14 @@ class OpenClipAdapter(BaseClipAdapter):
@property @property
def dim_latent(self): def dim_latent(self):
return 512 return self._dim_latent
@property @property
def image_size(self): def image_size(self):
return self.clip.visual.image_size image_size = self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
@property @property
def image_channels(self): def image_channels(self):
@@ -417,6 +433,7 @@ class OpenClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id) is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0 text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True) text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared assert not self.cleared
text_embed = self.clip.encode_text(text) text_embed = self.clip.encode_text(text)
@@ -602,7 +619,7 @@ class NoiseScheduler(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
return ( return (
@@ -610,6 +627,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
) )
def calculate_v(self, x_start, t, noise = None):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def q_sample_from_to(self, x_from, from_t, to_t, noise = None): def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from)) noise = default(noise, lambda: torch.randn_like(x_from))
@@ -621,6 +644,12 @@ class NoiseScheduler(nn.Module):
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_start_from_noise(self, x_t, t, noise): def predict_start_from_noise(self, x_t, t, noise):
return ( return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
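A small numerical sanity check of the two identities added above, using a scalar ᾱ for illustration: `calculate_v` forms v = √ᾱ·ε − √(1−ᾱ)·x₀, and `predict_start_from_v` inverts it back to x₀.

```python
import torch

alpha_bar = torch.tensor(0.7)                    # stand-in for alphas_cumprod at some timestep
a, s = alpha_bar.sqrt(), (1 - alpha_bar).sqrt()

x_start, noise = torch.randn(4), torch.randn(4)
x_t = a * x_start + s * noise                    # q_sample
v = a * noise - s * x_start                      # calculate_v
assert torch.allclose(a * x_t - s * v, x_start, atol = 1e-6)   # predict_start_from_v recovers x0
```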
@@ -638,6 +667,23 @@ class NoiseScheduler(nn.Module):
return loss return loss
return loss * extract(self.p2_loss_weight, times, loss.shape) return loss * extract(self.p2_loss_weight, times, loss.shape)
# rearrange image to sequence
class RearrangeToSequence(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
x = rearrange(x, 'b c ... -> b ... c')
x, ps = pack([x], 'b * c')
x = self.fn(x)
x, = unpack(x, ps, 'b * c')
x = rearrange(x, 'b ... c -> b c ...')
return x
# diffusion prior # diffusion prior
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
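A toy run of the pack/unpack pattern that `RearrangeToSequence` uses to replace the removed `EinopsToAndFrom` helper (einops >= 0.6 assumed):

```python
import torch
from einops import rearrange, pack, unpack

x = torch.randn(2, 32, 8, 8)                 # (b, c, h, w) feature map
x = rearrange(x, 'b c ... -> b ... c')
x, ps = pack([x], 'b * c')                   # flatten the spatial dims -> (2, 64, 32)
# ... attention over the flattened sequence would run here ...
x, = unpack(x, ps, 'b * c')                  # restore (2, 8, 8, 32)
x = rearrange(x, 'b ... c -> b c ...')
print(x.shape)                               # torch.Size([2, 32, 8, 8])
```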
@@ -836,7 +882,7 @@ class Attention(nn.Module):
# add null key / value for classifier free guidance in prior net # add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b) nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
@@ -873,6 +919,8 @@ class Attention(nn.Module):
# attention # attention
attn = sim.softmax(dim = -1, dtype = torch.float32) attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)
attn = self.dropout(attn) attn = self.dropout(attn)
# aggregate values # aggregate values
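A minimal illustration of the added cast, which keeps the softmax numerically stable in float32 while returning the attention matrix to the ambient half-precision dtype (relevant to the deepspeed fp16 fixes listed in the commits above):

```python
import torch

sim = torch.randn(1, 8, 4, 4, dtype = torch.float16)   # toy attention logits
attn = sim.softmax(dim = -1, dtype = torch.float32)    # softmax computed in fp32
attn = attn.type(sim.dtype)                            # the added cast back to fp16
print(attn.dtype)                                      # torch.float16
```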
@@ -954,6 +1002,8 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds) Rearrange('b (n d) -> b n d', n = num_text_embeds)
) )
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential( self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds) Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -970,7 +1020,10 @@ class DiffusionPriorNetwork(nn.Module):
# dalle1 learned padding strategy # dalle1 learned padding strategy
self.max_text_len = max_text_len self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
self.null_image_embed = nn.Parameter(torch.randn(1, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique # whether to use self conditioning, Hinton's group's new ddpm technique
@@ -987,7 +1040,7 @@ class DiffusionPriorNetwork(nn.Module):
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
def forward( def forward(
@@ -998,7 +1051,8 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, text_embed,
text_encodings = None, text_encodings = None,
self_cond = None, self_cond = None,
cond_drop_prob = 0. text_cond_drop_prob = 0.,
image_cond_drop_prob = 0.
): ):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -1016,6 +1070,14 @@ class DiffusionPriorNetwork(nn.Module):
text_embed = self.to_text_embeds(text_embed) text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed) image_embed = self.to_image_embeds(image_embed)
# classifier free guidance masks
text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
# make text encodings optional # make text encodings optional
# although the paper seems to suggest it is present <-- # although the paper seems to suggest it is present <--
@@ -1036,38 +1098,48 @@ class DiffusionPriorNetwork(nn.Module):
 text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
 mask = F.pad(mask, (0, remainder), value = False)

-null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+# mask out text encodings with null encodings
+
+null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

 text_encodings = torch.where(
-    rearrange(mask, 'b n -> b n 1').clone(),
+    rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
     text_encodings,
-    null_text_embeds
+    null_text_encodings
 )

-# classifier free guidance
-
-keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
-keep_mask = rearrange(keep_mask, 'b -> b 1')
-
-mask &= keep_mask
-
-# whether text embedding is masked or not depends on the classifier free guidance conditional masking
-
-keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
-mask = torch.cat((mask, keep_mask), dim = 1)
+# mask out text embeddings with null text embeddings
+
+null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
+
+text_embed = torch.where(
+    text_keep_mask,
+    text_embed,
+    null_text_embeds
+)
+
+# mask out image embeddings with null image embeddings
+
+null_image_embed = self.null_image_embed.to(image_embed.dtype)
+
+image_embed = torch.where(
+    image_keep_mask,
+    image_embed,
+    null_image_embed
+)

 # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
 # but let's just do it right

-attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
-mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
+if self.continuous_embedded_time:
+    diffusion_timesteps = diffusion_timesteps.type(dtype)

 time_embed = self.to_time_embeds(diffusion_timesteps)

 learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

 if self.self_cond:
-    learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+    learned_queries = torch.cat((self_cond, learned_queries), dim = -2)

 tokens = torch.cat((
     text_encodings,
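To make the refactor above easier to follow, a toy sketch of the split text/image classifier-free-guidance masks; `prob_mask_like` is the repo's helper, approximated here with `torch.rand` purely for illustration, and the null embedding is a random stand-in for the learned parameter:

```python
import torch
from einops import rearrange

batch, n, dim = 4, 3, 8
text_cond_drop_prob, image_cond_drop_prob = 0.5, 0.5

text_keep_mask = rearrange(torch.rand(batch) < (1 - text_cond_drop_prob), 'b -> b 1 1')
image_keep_mask = rearrange(torch.rand(batch) < (1 - image_cond_drop_prob), 'b -> b 1 1')

text_embed = torch.randn(batch, n, dim)
null_text_embeds = torch.randn(1, n, dim)                # stands in for the learned null parameter

# dropped samples are swapped for the null embedding instead of having their attention mask zeroed out
text_embed = torch.where(text_keep_mask, text_embed, null_text_embeds)
print(text_embed.shape)                                  # torch.Size([4, 3, 8])
```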
@@ -1099,8 +1171,11 @@ class DiffusionPrior(nn.Module):
timesteps = 1000, timesteps = 1000,
sample_timesteps = None, sample_timesteps = None,
cond_drop_prob = 0., cond_drop_prob = 0.,
text_cond_drop_prob = None,
image_cond_drop_prob = None,
loss_type = "l2", loss_type = "l2",
predict_x_start = True, predict_x_start = True,
predict_v = False,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs) sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1137,15 +1212,22 @@ class DiffusionPrior(nn.Module):
self.net = net self.net = net
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent) self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'
self.channels = default(image_channels, lambda: clip.image_channels) self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
self.can_classifier_guidance = cond_drop_prob > 0. self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start self.predict_x_start = predict_x_start
self.predict_v = predict_v # takes precedence over predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -1175,7 +1257,9 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond) pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start: if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif self.predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1224,7 +1308,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.): def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1] times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = list(reversed(times.int().tolist())) times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) time_pairs = list(zip(times[:-1], times[1:]))
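A quick look at what the revised schedule produces, assuming 1000 training timesteps and 10 sampling steps; the final pair now steps to -1, which pairs with the `if time_next < 0` early exit added in the following hunks:

```python
import torch

total_timesteps, timesteps = 1000, 10
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
print(time_pairs[0], time_pairs[-1])   # (899, 799) ... (99, -1): the last step denoises straight to x0
```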
@@ -1246,12 +1330,16 @@ class DiffusionPrior(nn.Module):
 pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)

-if self.predict_x_start:
+# derive x0
+
+if self.predict_v:
+    x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+elif self.predict_x_start:
     x_start = pred
-    pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
 else:
     x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
-    pred_noise = pred
+
+# clip x0 before maybe predicting noise

 if not self.predict_x_start:
     x_start.clamp_(-1., 1.)
@@ -1259,6 +1347,14 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm: if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start) x_start = self.l2norm_clamp_embed(x_start)
# predict noise
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
if time_next < 0:
image_embed = x_start
continue
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0. noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -1300,14 +1396,20 @@ class DiffusionPrior(nn.Module):
     image_embed_noisy,
     times,
     self_cond = self_cond,
-    cond_drop_prob = self.cond_drop_prob,
+    text_cond_drop_prob = self.text_cond_drop_prob,
+    image_cond_drop_prob = self.image_cond_drop_prob,
     **text_cond
 )

 if self.predict_x_start and self.training_clamp_l2norm:
     pred = self.l2norm_clamp_embed(pred)

-target = noise if not self.predict_x_start else image_embed
+if self.predict_v:
+    target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+elif self.predict_x_start:
+    target = image_embed
+else:
+    target = noise

 loss = self.noise_scheduler.loss_fn(pred, target)
 return loss
@@ -1377,7 +1479,7 @@ class DiffusionPrior(nn.Module):
**kwargs **kwargs
): ):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied' assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied' assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image): if exists(image):
@@ -1447,9 +1549,14 @@ class PixelShuffleUpsample(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
def Downsample(dim, *, dim_out = None): def Downsample(dim, dim_out = None):
# https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
# named SP-conv in the paper, but basically a pixel unshuffle
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1) return nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
nn.Conv2d(dim * 4, dim_out, 1)
)
class WeightStandardizedConv2d(nn.Conv2d): class WeightStandardizedConv2d(nn.Conv2d):
""" """
@@ -1478,6 +1585,8 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x): def forward(self, x):
dtype, device = x.dtype, x.device dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -1535,14 +1644,10 @@ class ResnetBlock(nn.Module):
 self.cross_attn = None

 if exists(cond_dim):
-    self.cross_attn = EinopsToAndFrom(
-        'b c h w',
-        'b (h w) c',
-        CrossAttention(
-            dim = dim_out,
-            context_dim = cond_dim,
-            cosine_sim = cosine_sim_cross_attn
-        )
+    self.cross_attn = CrossAttention(
+        dim = dim_out,
+        context_dim = cond_dim,
+        cosine_sim = cosine_sim_cross_attn
     )

 self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1561,8 +1666,15 @@ class ResnetBlock(nn.Module):
if exists(self.cross_attn): if exists(self.cross_attn):
assert exists(cond) assert exists(cond)
h = rearrange(h, 'b c ... -> b ... c')
h, ps = pack([h], 'b * c')
h = self.cross_attn(h, context = cond) + h h = self.cross_attn(h, context = cond) + h
h, = unpack(h, ps, 'b * c')
h = rearrange(h, 'b ... c -> b c ...')
h = self.block2(h) h = self.block2(h)
return h + self.res_conv(x) return h + self.res_conv(x)
@@ -1608,11 +1720,11 @@ class CrossAttention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
# add null key / value for classifier free guidance in prior net # add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b) nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
@@ -1631,6 +1743,7 @@ class CrossAttention(nn.Module):
sim = sim.masked_fill(~mask, max_neg_value) sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32) attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1664,7 +1777,7 @@ class LinearAttention(nn.Module):
fmap = self.norm(fmap) fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1) q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h) q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
q = q.softmax(dim = -1) q = q.softmax(dim = -1)
k = k.softmax(dim = -2) k = k.softmax(dim = -2)
@@ -1898,7 +2011,7 @@ class Unet(nn.Module):
self_attn = cast_tuple(self_attn, num_stages) self_attn = cast_tuple(self_attn, num_stages)
create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs))) create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
# resnet block klass # resnet block klass
@@ -2375,6 +2488,7 @@ class Decoder(nn.Module):
loss_type = 'l2', loss_type = 'l2',
beta_schedule = None, beta_schedule = None,
predict_x_start = False, predict_x_start = False,
predict_v = False,
predict_x_start_for_latent_diffusion = False, predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2397,7 +2511,7 @@ class Decoder(nn.Module):
dynamic_thres_percentile = 0.95, dynamic_thres_percentile = 0.95,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1, p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
): ):
super().__init__() super().__init__()
@@ -2547,6 +2661,10 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# predict v
self.predict_v = cast_tuple(predict_v, len(unets))
# input image range # input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.) self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2627,11 +2745,16 @@ class Decoder(nn.Module):
 if exists(unet_number):
     unet = self.get_unet(unet_number)

+# devices
+
+cuda, cpu = torch.device('cuda'), torch.device('cpu')
+
 self.cuda()

 devices = [module_device(unet) for unet in self.unets]

-self.unets.cpu()
-unet.cuda()
+self.unets.to(cpu)
+unet.to(cuda)

 yield
@@ -2658,14 +2781,16 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s x = x.clamp(-s, s) / s
return x return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level)) model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output) pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
if predict_x_start: if predict_v:
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif predict_x_start:
x_start = pred x_start = pred
else: else:
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2692,9 +2817,9 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance, x_start return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2709,6 +2834,7 @@ class Decoder(nn.Module):
image_embed, image_embed,
noise_scheduler, noise_scheduler,
predict_x_start = False, predict_x_start = False,
predict_v = False,
learned_variance = False, learned_variance = False,
clip_denoised = True, clip_denoised = True,
lowres_cond_img = None, lowres_cond_img = None,
@@ -2767,6 +2893,7 @@ class Decoder(nn.Module):
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level, lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
predict_v = predict_v,
noise_scheduler = noise_scheduler, noise_scheduler = noise_scheduler,
learned_variance = learned_variance, learned_variance = learned_variance,
clip_denoised = clip_denoised clip_denoised = clip_denoised
@@ -2792,6 +2919,7 @@ class Decoder(nn.Module):
timesteps, timesteps,
eta = 1., eta = 1.,
predict_x_start = False, predict_x_start = False,
predict_v = False,
learned_variance = False, learned_variance = False,
clip_denoised = True, clip_denoised = True,
lowres_cond_img = None, lowres_cond_img = None,
@@ -2803,12 +2931,13 @@ class Decoder(nn.Module):
inpaint_mask = None, inpaint_mask = None,
inpaint_resample_times = 5 inpaint_resample_times = 5
): ):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1] times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist())) times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) time_pairs = list(zip(times[:-1], times[1:]))
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
is_inpaint = exists(inpaint_image) is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1 resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2850,16 +2979,24 @@ class Decoder(nn.Module):
 pred, _ = self.parse_unet_output(learned_variance, unet_output)

-if predict_x_start:
+# predict x0
+
+if predict_v:
+    x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+elif predict_x_start:
     x_start = pred
-    pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
 else:
     x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
-    pred_noise = pred
+
+# maybe clip x0

 if clip_denoised:
     x_start = self.dynamic_threshold(x_start)

+# predict noise
+
+pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt() c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0. noise = torch.randn_like(img) if not is_last_timestep else 0.
@@ -2892,7 +3029,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs) return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None): def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1] # normalize to [-1, 1]
@@ -2937,7 +3074,12 @@ class Decoder(nn.Module):
 pred, _ = self.parse_unet_output(learned_variance, unet_output)

-target = noise if not predict_x_start else x_start
+if predict_v:
+    target = noise_scheduler.calculate_v(x_start, times, noise)
+elif predict_x_start:
+    target = x_start
+else:
+    target = noise

 loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
 loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -2995,7 +3137,8 @@ class Decoder(nn.Module):
distributed = False, distributed = False,
inpaint_image = None, inpaint_image = None,
inpaint_mask = None, inpaint_mask = None,
inpaint_resample_times = 5 inpaint_resample_times = 5,
one_unet_in_gpu_at_time = True
): ):
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
@@ -3018,16 +3161,17 @@ class Decoder(nn.Module):
assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size) assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
prev_unet_output_size = self.image_sizes[start_at_unet_number - 2] prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
img = resize_image_to(image, prev_unet_output_size, nearest = True) img = resize_image_to(image, prev_unet_output_size, nearest = True)
is_cuda = next(self.parameters()).is_cuda is_cuda = next(self.parameters()).is_cuda
num_unets = self.num_unets num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets) cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)): for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number: if unet_number < start_at_unet_number:
continue # It's the easiest way to do it continue # It's the easiest way to do it
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context() context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()
with context: with context:
# prepare low resolution conditioning for upsamplers # prepare low resolution conditioning for upsamplers
@@ -3059,6 +3203,7 @@ class Decoder(nn.Module):
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = unet_cond_scale, cond_scale = unet_cond_scale,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
predict_v = predict_v,
learned_variance = learned_variance, learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion, clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
@@ -3098,11 +3243,12 @@ class Decoder(nn.Module):
lowres_conditioner = self.lowres_conds[unet_index] lowres_conditioner = self.lowres_conds[unet_index]
target_image_size = self.image_sizes[unet_index] target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index] predict_x_start = self.predict_x_start[unet_index]
predict_v = self.predict_v[unet_index]
random_crop_size = self.random_crop_sizes[unet_index] random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index] learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels) assert image.shape[1] == self.channels
assert h >= target_image_size and w >= target_image_size assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long) times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
@@ -3136,7 +3282,7 @@ class Decoder(nn.Module):
image = vae.encode(image) image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img) lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level) losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
if not return_lowres_cond_image: if not return_lowres_cond_image:
return losses return losses

View File

@@ -1,14 +1,16 @@
import json import json
from torchvision import transforms as T from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator from pydantic import BaseModel, validator, model_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from x_clip import CLIP as XCLIP from x_clip import CLIP as XCLIP
from open_clip import list_pretrained
from coca_pytorch import CoCa from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import ( from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter, CoCaAdapter,
OpenAIClipAdapter, OpenAIClipAdapter,
OpenClipAdapter,
Unet, Unet,
Decoder, Decoder,
DiffusionPrior, DiffusionPrior,
@@ -36,12 +38,12 @@ class TrainSplitConfig(BaseModel):
val: float = 0.15 val: float = 0.15
test: float = 0.1 test: float = 0.1
@root_validator @model_validator(mode = 'after')
def validate_all(cls, fields): def validate_all(self, m):
actual_sum = sum([*fields.values()]) actual_sum = sum([*dict(self).values()])
if actual_sum != 1.: if actual_sum != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}') raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
return fields return self
class TrackerLogConfig(BaseModel): class TrackerLogConfig(BaseModel):
log_type: str = 'console' log_type: str = 'console'
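A minimal pydantic v2 sketch of the `model_validator(mode = 'after')` pattern the config now uses, with the same default splits assumed:

```python
from pydantic import BaseModel, model_validator

class TrainSplitConfig(BaseModel):
    train: float = 0.75
    val: float = 0.15
    test: float = 0.1

    @model_validator(mode = 'after')
    def validate_all(self):
        # in v2 an 'after' validator receives the constructed model rather than a dict of fields
        if sum(dict(self).values()) != 1.:
            raise ValueError('train/val/test splits must sum to 1.0')
        return self

TrainSplitConfig()                       # passes with the defaults
# TrainSplitConfig(train = 0.9)          # would raise: splits sum to 1.15
```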
@@ -57,6 +59,7 @@ class TrackerLogConfig(BaseModel):
kwargs = self.dict() kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs) return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel): class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
@@ -87,7 +90,7 @@ class TrackerConfig(BaseModel):
data_path: str = '.tracker_data' data_path: str = '.tracker_data'
overwrite_data_path: bool = False overwrite_data_path: bool = False
log: TrackerLogConfig log: TrackerLogConfig
load: Optional[TrackerLoadConfig] load: Optional[TrackerLoadConfig] = None
save: Union[List[TrackerSaveConfig], TrackerSaveConfig] save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker: def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
@@ -112,11 +115,15 @@ class TrackerConfig(BaseModel):
 class AdapterConfig(BaseModel):
     make: str = "openai"
     model: str = "ViT-L/14"
-    base_model_kwargs: Dict[str, Any] = None
+    base_model_kwargs: Optional[Dict[str, Any]] = None

     def create(self):
         if self.make == "openai":
             return OpenAIClipAdapter(self.model)
+        elif self.make == "open_clip":
+            pretrained = dict(list_pretrained())
+            checkpoint = pretrained[self.model]
+            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
         elif self.make == "x-clip":
             return XClipAdapter(XCLIP(**self.base_model_kwargs))
         elif self.make == "coca":
@@ -127,8 +134,8 @@ class AdapterConfig(BaseModel):
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
     depth: int
-    max_text_len: int = None
-    num_timesteps: int = None
+    max_text_len: Optional[int] = None
+    num_timesteps: Optional[int] = None
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
@@ -151,7 +158,7 @@ class DiffusionPriorNetworkConfig(BaseModel):
         return DiffusionPriorNetwork(**kwargs)

 class DiffusionPriorConfig(BaseModel):
-    clip: AdapterConfig = None
+    clip: Optional[AdapterConfig] = None
     net: DiffusionPriorNetworkConfig
     image_embed_dim: int
     image_size: int
@@ -188,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
-    warmup_steps: int = None # number of warmup steps
+    warmup_steps: Optional[int] = None # number of warmup steps
     save_every_seconds: int = 3600 # how often to save
     eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
     best_validation_loss: float = 1e9 # the current best validation loss observed
@@ -221,12 +228,12 @@ class TrainDiffusionPriorConfig(BaseModel):
 class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple[int]
-    image_embed_dim: int = None
-    text_embed_dim: int = None
-    cond_on_text_encodings: bool = None
-    cond_dim: int = None
+    image_embed_dim: Optional[int] = None
+    text_embed_dim: Optional[int] = None
+    cond_on_text_encodings: Optional[bool] = None
+    cond_dim: Optional[int] = None
     channels: int = 3
-    self_attn: ListOrTuple[int]
+    self_attn: SingularOrIterable[bool] = False
     attn_dim_head: int = 32
     attn_heads: int = 16
     init_cross_embed: bool = True
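With the change above, `self_attn` accepts either a single bool (applied to every resolution stage) or an iterable of bools (one per stage), and it now has a default, so configs that omitted it keep validating. An illustrative config fragment, with made-up values and assuming the two required fields shown in the hunk are the only ones needed:

    from dalle2_pytorch.train_configs import UnetConfig

    # one flag for all stages
    unet_a = UnetConfig(dim = 128, dim_mults = [1, 2, 4, 8], self_attn = True)

    # or one flag per stage, matching dim_mults
    unet_b = UnetConfig(dim = 128, dim_mults = [1, 2, 4, 8], self_attn = [False, False, True, True])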
@@ -236,14 +243,14 @@ class UnetConfig(BaseModel):
 class DecoderConfig(BaseModel):
     unets: ListOrTuple[UnetConfig]
-    image_size: int = None
+    image_size: Optional[int] = None
     image_sizes: ListOrTuple[int] = None
-    clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
+    clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable[int]] = None
+    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
-    beta_schedule: ListOrTuple[str] = None # None means all cosine
+    beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
@@ -271,9 +278,9 @@ class DecoderConfig(BaseModel):
         extra = "allow"

 class DecoderDataConfig(BaseModel):
     webdataset_base_url: str # path to a webdataset with jpg images
-    img_embeddings_url: Optional[str] # path to .npy files with embeddings
-    text_embeddings_url: Optional[str] # path to .npy files with embeddings
+    img_embeddings_url: Optional[str] = None # path to .npy files with embeddings
+    text_embeddings_url: Optional[str] = None # path to .npy files with embeddings
     num_workers: int = 4
     batch_size: int = 64
     start_shard: int = 0
@@ -307,25 +314,26 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
+    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
     cond_scale: Union[float, List[float]] = 1.0
     device: str = 'cuda:0'
-    epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
-    validation_samples: int = None # Same as above but for validation.
+    epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+    validation_samples: Optional[int] = None # Same as above but for validation.
     save_immediately: bool = False
     use_ema: bool = True
     ema_beta: float = 0.999
     amp: bool = False
-    unet_training_mask: ListOrTuple[bool] = None # If None, use all unets
+    unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets

 class DecoderEvaluateConfig(BaseModel):
     n_evaluation_samples: int = 1000
-    FID: Dict[str, Any] = None
-    IS: Dict[str, Any] = None
-    KID: Dict[str, Any] = None
-    LPIPS: Dict[str, Any] = None
+    FID: Optional[Dict[str, Any]] = None
+    IS: Optional[Dict[str, Any]] = None
+    KID: Optional[Dict[str, Any]] = None
+    LPIPS: Optional[Dict[str, Any]] = None

 class TrainDecoderConfig(BaseModel):
     decoder: DecoderConfig
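Most of the field changes in this file are the same pydantic v2 adjustment: v1 implicitly treated `Optional[X]` (and even a bare `int = None`) as a nullable field defaulting to `None`, while v2 requires both the `Optional[...]` annotation and an explicit `= None` for the field to remain optional. A small standalone illustration (not repo code):

    from typing import Optional
    from pydantic import BaseModel, ValidationError

    class V2Style(BaseModel):
        required_field: Optional[int]          # no default: still required in pydantic v2
        optional_field: Optional[int] = None   # explicit default: truly optional

    V2Style(required_field = None)             # fine
    try:
        V2Style()                              # missing required_field
    except ValidationError as e:
        print(e.errors()[0]['type'])           # 'missing'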
@@ -339,11 +347,14 @@ class TrainDecoderConfig(BaseModel):
     def from_json_path(cls, json_path):
         with open(json_path) as f:
             config = json.load(f)
+            print(config)
         return cls(**config)

-    @root_validator
-    def check_has_embeddings(cls, values):
+    @model_validator(mode = 'after')
+    def check_has_embeddings(self, m):
         # Makes sure that enough information is provided to get the embeddings specified for training
+        values = dict(self)
         data_config, decoder_config = values.get('data'), values.get('decoder')

         if not exists(data_config) or not exists(decoder_config):
@@ -368,4 +379,4 @@ class TrainDecoderConfig(BaseModel):
         if text_emb_url:
             assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."

-        return values
+        return m
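For completeness, the JSON entry point this class exposes is unchanged by the migration; loading a decoder training config still looks like the sketch below (the path is a placeholder, point it at your own config file):

    from dalle2_pytorch.train_configs import TrainDecoderConfig

    # placeholder path; substitute your own decoder config JSON
    config = TrainDecoderConfig.from_json_path("configs/train_decoder_config.example.json")
    print(config.decoder.timesteps, config.train.use_ema)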

View File

@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -234,7 +235,10 @@ class DiffusionPriorTrainer(nn.Module):
             **kwargs
         )

-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))

         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmup step
         if exists(self.warmup_scheduler):
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
         # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
         if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                 self.scheduler.step()

         if self.use_ema:
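Taken together, these trainer hunks make cosine decay optional on top of the (now also optional) linear warmup. A standalone sketch of the same wiring with plain PyTorch and the pytorch-warmup package; the model, learning rate, and step counts are placeholders:

    import torch
    from contextlib import nullcontext
    from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
    import pytorch_warmup as warmup

    model = torch.nn.Linear(16, 16)                   # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

    cosine_decay_max_steps = 1000                     # None would fall back to a constant LR
    warmup_steps = 100                                # None would skip warmup entirely

    scheduler = (
        CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
        if cosine_decay_max_steps is not None
        else LambdaLR(optimizer, lr_lambda = lambda _: 1.0)
    )
    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if warmup_steps is not None else None

    for step in range(10):                            # toy training loop
        optimizer.zero_grad()
        model(torch.randn(4, 16)).sum().backward()
        optimizer.step()
        # dampening() scales the scheduler's LR while warmup is still active
        sched_context = warmup_scheduler.dampening if warmup_scheduler is not None else nullcontext
        with sched_context():
            scheduler.step()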

View File

@@ -1 +1 @@
-__version__ = '1.8.0'
+__version__ = '1.15.6'

View File

@@ -11,8 +11,7 @@ import torch.nn.functional as F
 from torch.autograd import grad as torch_grad
 import torchvision

-from einops import rearrange, reduce, repeat
-from einops_exts import rearrange_many
+from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange

 # constants
@@ -408,7 +407,7 @@ class Attention(nn.Module):
         x = self.norm(x)

         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

         q = q * self.scale
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
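Dropping `einops_exts` removes a dependency that got in the way of `torch.compile`; `rearrange_many` is just a convenience over mapping `rearrange` across several tensors, as the equivalent sketch below shows (shapes are arbitrary examples):

    import torch
    from einops import rearrange

    b, n, h, d = 2, 10, 8, 64
    q, k, v = (torch.randn(b, n, h * d) for _ in range(3))

    # equivalent to einops_exts.rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
    assert q.shape == (b, h, n, d)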

View File

@@ -26,17 +26,17 @@ setup(
   install_requires=[
     'accelerate',
     'click',
-    'clip-anytorch>=2.4.0',
+    'open-clip-torch>=2.0.0,<3.0.0',
+    'clip-anytorch>=2.5.2',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
-    'einops>=0.4',
-    'einops-exts>=0.0.3',
+    'einops>=0.7.0',
     'embedding-reader',
     'kornia>=0.5.4',
     'numpy',
     'packaging',
     'pillow',
-    'pydantic',
+    'pydantic>=2',
     'pytorch-warmup',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',

View File

@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

-def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if img_embeddings[0] is None:
         # Generate image embeddings from clip
         imgs_tensor = torch.stack(real_images)
-        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor.to(device=device)
+        img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
         sample_params["image_embed"] = img_embeddings
     else:
         # Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if condition_on_text_encodings:
         if text_embeddings[0] is None:
             # Generate text embeddings from text
-            tokenized_texts = tokenize(txts, truncate=True)
-            sample_params["text"] = tokenized_texts
+            assert clip is not None, "clip is None, but text_embeddings is None"
+            tokenized_texts = tokenize(txts, truncate=True).to(device=device)
+            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+            sample_params["text_encodings"] = text_encodings
         else:
             # Then we are using precomputed text embeddings
             text_embeddings = torch.stack(text_embeddings)
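Both hunks stop handing raw tokens to the trainer and instead embed with the standalone clip adapter, whose `embed_image` / `embed_text` methods return an `(embedding, encodings)` pair. A hedged sketch of that call pattern in isolation; the adapter choice, image batch, and captions are placeholders, and instantiating the adapter downloads the OpenAI CLIP weights:

    import torch
    from clip import tokenize                      # provided by clip-anytorch
    from dalle2_pytorch import OpenAIClipAdapter

    clip_adapter = OpenAIClipAdapter('ViT-L/14')

    images = torch.rand(2, 3, 224, 224)            # placeholder image batch
    # returns a (pooled embedding, encodings) pair; the encodings may be None for the image branch
    image_embed, image_encodings = clip_adapter.embed_image(images)

    tokens = tokenize(['a placeholder caption', 'another placeholder caption'], truncate = True)
    text_embed, text_encodings = clip_adapter.embed_text(tokens)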
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     sample_params["image"] = torch.stack(real_images)
     if device is not None:
         sample_params["_device"] = device
-    samples = trainer.sample(**sample_params)
+    samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     if match_image_size:
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions

-def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -225,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
         metrics["KID_std"] = kid_std.item()
     if exists(LPIPS):
         # Convert from [0, 1] to [-1, 1]
-        renorm_real_images = real_images.mul(2).sub(1)
-        renorm_generated_images = generated_images.mul(2).sub(1)
+        renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
+        renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
         lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
         lpips.to(device=device)
         lpips.update(renorm_real_images, renorm_generated_images)
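The added `.clamp(-1, 1)` matters because, with its default settings, torchmetrics' LPIPS metric expects inputs in [-1, 1] and rejects batches that fall outside that range; sampled images can overshoot [0, 1] slightly. A small sketch of the renormalization step on its own, with dummy tensors and default metric settings (the backbone weights are downloaded on first use):

    import torch
    from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

    lpips = LearnedPerceptualImagePatchSimilarity()   # expects inputs in [-1, 1] by default

    real = torch.rand(4, 3, 64, 64)                   # dummy images in [0, 1]
    fake = torch.rand(4, 3, 64, 64).mul(1.02)         # slight overshoot, as sampled images sometimes have

    real = real.mul(2).sub(1).clamp(-1, 1)
    fake = fake.mul(2).sub(1).clamp(-1, 1)            # without the clamp, update() would reject the batch

    lpips.update(real, fake)
    print(lpips.compute())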
@@ -265,6 +269,7 @@ def train(
     accelerator: Accelerator,
     tracker: Tracker,
     inference_device,
+    clip=None,
     evaluate_config=None,
     epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples = None,
@@ -371,15 +376,19 @@ def train(
                     forward_params['image_embed'] = img_emb
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb
                     else:
                         # Then we need to pass the text instead
-                        tokenized_texts = tokenize(txt, truncate=True)
+                        assert clip is not None
+                        tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                 trainer.update(unet_number=unet)
                 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -419,7 +428,7 @@ def train(
                 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

             if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def train(
                     forward_params['image_embed'] = img_emb.float()
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb.float()
                     else:
                         # Then we need to pass the text instead
-                        tokenized_texts = tokenize(txt, truncate=True)
+                        assert clip is not None
+                        tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                 average_val_loss_tensor[0, unet-1] += loss
@@ -498,7 +511,7 @@ def train(
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
                 next_task = 'sample'
@@ -509,8 +522,8 @@ def train(
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -532,9 +545,10 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
         "NumProcesses": accelerator.num_processes,
         "MixedPrecision": accelerator.mixed_precision
     }
+    accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
     tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
     tracker.save_config(config_path, config_name='decoder_config.json')
-    tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
+    tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())
     return tracker

 def initialize_training(config: TrainDecoderConfig, config_path):
@@ -542,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)

     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
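The new `static_graph` flag is forwarded to torch's DistributedDataParallel through accelerate's kwargs handler; with a static graph, DDP can skip the per-iteration unused-parameter search and cooperates better with gradient checkpointing. A minimal sketch of the accelerate setup in isolation, with values hardcoded instead of read from the training config:

    from datetime import timedelta
    from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters = True,   # config.train.find_unused_parameters in the script
        static_graph = True,             # the new config.train.static_graph flag
    )
    init_kwargs = InitProcessGroupKwargs(timeout = timedelta(hours = 1))
    accelerator = Accelerator(kwargs_handlers = [ddp_kwargs, init_kwargs])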
@@ -556,10 +570,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
         raise ValueError("DeepSpeed fp16 mode does not support learned variance")

-    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
-        # This is an invalid configuration until we figure out how to handle this
-        raise ValueError("DeepSpeed does not support multi-node distributed training")
-
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
     world_size = accelerator.num_processes
@@ -567,6 +577,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     shards_per_process = len(all_shards) // world_size
     assert shards_per_process > 0, "Not enough shards to split evenly"
     my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]

     dataloaders = create_dataloaders (
         available_shards=my_shards,
         img_preproc = config.data.img_preproc,
@@ -574,11 +585,16 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         val_prop = config.data.splits.val,
         test_prop = config.data.splits.test,
         n_sample_images=config.train.n_sample_images,
-        **config.data.dict(),
+        **config.data.model_dump(),
         rank = rank,
         seed = config.seed,
     )

+    # If clip is in the model, we need to remove it for compatibility with deepspeed
+    clip = None
+    if config.decoder.clip is not None:
+        clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
+        config.decoder.clip = None
+
     # Create the decoder model and print basic info
     decoder = config.decoder.create()
     get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
@@ -590,7 +606,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     has_text_embeddings = config.data.text_embeddings_url is not None
     conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
-    has_clip_model = config.decoder.clip is not None
+    has_clip_model = clip is not None
     data_source_string = ""

     if has_img_embeddings:
@@ -615,11 +631,12 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

     train(dataloaders, decoder, accelerator,
+        clip=clip,
         tracker=tracker,
         inference_device=accelerator.device,
         evaluate_config=config.evaluate,
         condition_on_text_encodings=conditioning_on_text,
-        **config.train.dict(),
+        **config.train.model_dump(),
     )

 # Create a simple click command line interface to load the config and start the training