Compare commits

...

32 Commits

Author SHA1 Message Date
Phil Wang
683dd98b96 extra insurance in case eos id is not there 2022-12-15 10:54:21 -08:00
Phil Wang
067ac323da address https://github.com/lucidrains/DALLE2-pytorch/issues/266 2022-11-23 08:41:25 -08:00
zion
91c8d1ca13 bug fix cosine annealing optimizer in prior trainer (#262) 2022-11-11 12:15:13 -08:00
zion
08238a7200 depend on open-clip-torch (#261)
fix the previous commit which assumes open_clip is installed
2022-11-07 16:19:08 -08:00
zion
7166ad6711 add open clip to train_config (#260)
add the ability to use open_clip in the train configs (useful for the new SOTA h/14 model)
2022-11-07 15:44:36 -08:00
Phil Wang
fbba0f9aaf bring in prediction of v objective, combining the findings from progressive distillation paper and imagen-video to the eventual extension of dalle2 to make-a-video 2022-10-28 18:21:07 -07:00
Romain Beaumont
9f37705d87 Add static graph param (#226)
* Add static graph param

* use static graph param
2022-10-25 19:31:29 +02:00
Phil Wang
c3df46e374 fix openclipadapter to be able to use latest open sourced sota model 2022-10-23 15:12:09 -07:00
Phil Wang
41fabf2922 fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT 2022-10-19 09:26:06 -07:00
Heng Jia
5975e8222b Fix assert message (#253) 2022-10-18 08:50:59 -07:00
Phil Wang
c18c080128 fix for use with larger openai clip models by extracting dimension of last layernorm in clip 2022-09-29 09:09:47 -07:00
Phil Wang
b39653cf96 fix readme dataloader example 2022-09-20 08:39:52 -07:00
Phil Wang
39f8b6cf16 show example of using SOTA open sourced open clip 2022-09-19 10:45:20 -07:00
Phil Wang
d0c11b30b0 handle open clip adapter image size being a tuple 2022-09-19 10:27:14 -07:00
zion
86e2d5ba84 Minor Decoder Train Script Fixes (#242)
* ensure tokenized text is on proper device
* fix lpips mage distribution
2022-09-15 17:21:48 -07:00
Phil Wang
0d82dff9c5 in ddim, noise should be predicted after x0 is maybe clipped, thanks to @lukovnikov for pointing this out in another repository 2022-09-01 09:40:47 -07:00
Phil Wang
8bbc956ff1 fix bug with misnamed variable in diffusion prior network 2022-08-31 17:19:05 -07:00
Phil Wang
22019fddeb todo 2022-08-31 13:36:05 -07:00
Phil Wang
6fb7e91343 fix ddim to use alpha_cumprod 2022-08-31 07:40:46 -07:00
Phil Wang
ba58ae0bf2 add two asserts to diffusion prior to ensure matching image embedding dimensions for clip, diffusion prior network, and what was set on diffusion prior 2022-08-28 10:11:37 -07:00
Phil Wang
1cc5d0afa7 upgrade to best downsample 2022-08-25 10:37:02 -07:00
Phil Wang
59fa101c4d fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue 2022-08-23 08:29:01 -07:00
Aidan Dempster
916ece164c Merge pull request #234 from Veldrovive/deepspeed_fp16
Fixed issues with clip and deepspeed fp16
2022-08-20 19:01:43 -04:00
Aidan
cbaadb6931 Fixed issues with clip and deepspeed fp16
Also more more general compatibility fixes
2022-08-20 17:58:32 +00:00
Phil Wang
083508ff8e cast attention matrix back to original dtype pre-softmax in attention 2022-08-20 10:56:01 -07:00
Phil Wang
7762edd0ff make it work for @ethancohen123 2022-08-19 11:28:58 -07:00
Phil Wang
de5e628773 cite einops 2022-08-17 08:58:41 -07:00
Phil Wang
1b4046b039 gratitude 2022-08-17 08:57:33 -07:00
Phil Wang
27f19ba7fa make sure diffusion prior trainer can operate with no warmup 2022-08-15 14:27:40 -07:00
Phil Wang
8f38339c2b give diffusion prior trainer cosine annealing lr too 2022-08-15 07:38:01 -07:00
Phil Wang
6b9b4b9e5e add cosine annealing lr schedule 2022-08-15 07:29:56 -07:00
Phil Wang
44e09d5a4d add weight standardization behind feature flag, which may potentially work well with group norm 2022-08-14 11:34:45 -07:00
8 changed files with 341 additions and 94 deletions

View File

@@ -49,6 +49,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏
@@ -633,10 +634,12 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
$ pip install open-clip-torch
```
Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
clip = OpenClipAdapter('ViT-H/14')
```
Now you'll just have to worry about training the Prior and the Decoder!
@@ -1065,7 +1068,7 @@ dataloader = create_image_embedding_dataloader(
)
for img, emb in dataloader:
print(img.shape) # torch.Size([32, 3, 256, 256])
print(emb.shape) # torch.Size([32, 512])
print(emb["img"].shape) # torch.Size([32, 512])
# Train decoder only as shown above
# Or create a dataset without a loader so you can configure it manually
@@ -1125,6 +1128,7 @@ For detailed information on training the diffusion prior, please refer to the [d
- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
- [ ] add simple outpainting, text-guided 2x size the image for starters
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
## Citations
@@ -1264,4 +1268,44 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@article{Qiao2019WeightS,
title = {Weight Standardization},
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
journal = {ArXiv},
year = {2019},
volume = {abs/1903.10520}
}
```
```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE

View File

@@ -100,6 +100,9 @@ def eval_decorator(fn):
return out
return inner
def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
@@ -250,9 +253,15 @@ class XClipAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
encoder_output_is_cls = encoder_output.ndim == 3
text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
if exists(text_encodings):
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings)
@torch.no_grad()
@@ -308,7 +317,10 @@ class OpenAIClipAdapter(BaseClipAdapter):
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer('ln_final')
self.dim_latent_ = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -327,7 +339,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return 512
return self.dim_latent_
@property
def image_size(self):
@@ -348,6 +360,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -377,6 +390,8 @@ class OpenClipAdapter(BaseClipAdapter):
self.eos_id = 49407
text_attention_final = self.find_layer('ln_final')
self._dim_latent = text_attention_final.weight.shape[0]
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
@@ -396,11 +411,14 @@ class OpenClipAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return 512
return self._dim_latent
@property
def image_size(self):
return self.clip.visual.image_size
image_size = self.clip.visual.image_size
if isinstance(image_size, tuple):
return max(image_size)
return image_size
@property
def image_channels(self):
@@ -417,6 +435,7 @@ class OpenClipAdapter(BaseClipAdapter):
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
text_mask = text_mask & (text != 0)
assert not self.cleared
text_embed = self.clip.encode_text(text)
@@ -602,7 +621,7 @@ class NoiseScheduler(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None):
def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
@@ -610,6 +629,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def calculate_v(self, x_start, t, noise = None):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape = x_from.shape
noise = default(noise, lambda: torch.randn_like(x_from))
@@ -621,6 +646,12 @@ class NoiseScheduler(nn.Module):
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -873,6 +904,8 @@ class Attention(nn.Module):
# attention
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)
attn = self.dropout(attn)
# aggregate values
@@ -954,6 +987,8 @@ class DiffusionPriorNetwork(nn.Module):
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.continuous_embedded_time = not exists(num_timesteps)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -970,7 +1005,10 @@ class DiffusionPriorNetwork(nn.Module):
# dalle1 learned padding strategy
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
self.null_image_embed = nn.Parameter(torch.randn(1, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique
@@ -987,7 +1025,7 @@ class DiffusionPriorNetwork(nn.Module):
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -998,7 +1036,8 @@ class DiffusionPriorNetwork(nn.Module):
text_embed,
text_encodings = None,
self_cond = None,
cond_drop_prob = 0.
text_cond_drop_prob = 0.,
image_cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
@@ -1016,6 +1055,14 @@ class DiffusionPriorNetwork(nn.Module):
text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed)
# classifier free guidance masks
text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
# make text encodings optional
# although the paper seems to suggest it is present <--
@@ -1036,31 +1083,41 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
# mask out text encodings with null encodings
null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)
text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1').clone(),
rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
text_encodings,
null_text_encodings
)
# mask out text embeddings with null text embeddings
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
text_embed = torch.where(
text_keep_mask,
text_embed,
null_text_embeds
)
# classifier free guidance
# mask out image embeddings with null image embeddings
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
keep_mask = rearrange(keep_mask, 'b -> b 1')
null_image_embed = self.null_image_embed.to(image_embed.dtype)
mask &= keep_mask
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
mask = torch.cat((mask, keep_mask), dim = 1)
image_embed = torch.where(
image_keep_mask,
image_embed,
null_image_embed
)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps)
@@ -1099,8 +1156,11 @@ class DiffusionPrior(nn.Module):
timesteps = 1000,
sample_timesteps = None,
cond_drop_prob = 0.,
text_cond_drop_prob = None,
image_cond_drop_prob = None,
loss_type = "l2",
predict_x_start = True,
predict_v = False,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1137,15 +1197,22 @@ class DiffusionPrior(nn.Module):
self.net = net
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.
self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
self.predict_v = predict_v # takes precedence over predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
@@ -1175,7 +1242,9 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start:
if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif self.predict_x_start:
x_start = pred
else:
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1224,7 +1293,7 @@ class DiffusionPrior(nn.Module):
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
@@ -1246,12 +1315,16 @@ class DiffusionPrior(nn.Module):
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
# derive x0
if self.predict_v:
x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
elif self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
# clip x0 before maybe predicting noise
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
@@ -1259,6 +1332,17 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
# predict noise
if self.predict_x_start or self.predict_v:
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
else:
pred_noise = pred
if time_next < 0:
image_embed = x_start
continue
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -1300,14 +1384,20 @@ class DiffusionPrior(nn.Module):
image_embed_noisy,
times,
self_cond = self_cond,
cond_drop_prob = self.cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
image_cond_drop_prob = self.image_cond_drop_prob,
**text_cond
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed
if self.predict_v:
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
elif self.predict_x_start:
target = image_embed
else:
target = noise
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@@ -1377,7 +1467,7 @@ class DiffusionPrior(nn.Module):
**kwargs
):
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(image):
@@ -1447,9 +1537,34 @@ class PixelShuffleUpsample(nn.Module):
def forward(self, x):
return self.net(x)
def Downsample(dim, *, dim_out = None):
def Downsample(dim, dim_out = None):
# https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample
# named SP-conv in the paper, but basically a pixel unshuffle
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
return nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
nn.Conv2d(dim * 4, dim_out, 1)
)
class WeightStandardizedConv2d(nn.Conv2d):
"""
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
weight = self.weight
flattened_weights = rearrange(weight, 'o ... -> o (...)')
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = torch.var(flattened_weights, dim = -1, unbiased = False)
var = rearrange(var, 'o -> o 1 1 1')
weight = (weight - mean) * (var + eps).rsqrt()
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -1458,6 +1573,8 @@ class SinusoidalPosEmb(nn.Module):
def forward(self, x):
dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -1469,10 +1586,13 @@ class Block(nn.Module):
self,
dim,
dim_out,
groups = 8
groups = 8,
weight_standardization = False
):
super().__init__()
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
self.project = conv_klass(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
@@ -1496,6 +1616,7 @@ class ResnetBlock(nn.Module):
cond_dim = None,
time_cond_dim = None,
groups = 8,
weight_standardization = False,
cosine_sim_cross_attn = False
):
super().__init__()
@@ -1521,8 +1642,8 @@ class ResnetBlock(nn.Module):
)
)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None, cond = None):
@@ -1607,6 +1728,7 @@ class CrossAttention(nn.Module):
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = attn.type(sim.dtype)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1747,6 +1869,7 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
resnet_weight_standardization = False,
num_resnet_blocks = 2,
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1894,7 +2017,7 @@ class Unet(nn.Module):
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
# give memory efficient unet an initial resnet block
@@ -2350,6 +2473,7 @@ class Decoder(nn.Module):
loss_type = 'l2',
beta_schedule = None,
predict_x_start = False,
predict_v = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2522,6 +2646,10 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# predict v
self.predict_v = cast_tuple(predict_v, len(unets))
# input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2633,14 +2761,16 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
if predict_x_start:
if predict_v:
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
elif predict_x_start:
x_start = pred
else:
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2667,9 +2797,9 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2684,6 +2814,7 @@ class Decoder(nn.Module):
image_embed,
noise_scheduler,
predict_x_start = False,
predict_v = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
@@ -2742,6 +2873,7 @@ class Decoder(nn.Module):
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start,
predict_v = predict_v,
noise_scheduler = noise_scheduler,
learned_variance = learned_variance,
clip_denoised = clip_denoised
@@ -2767,6 +2899,7 @@ class Decoder(nn.Module):
timesteps,
eta = 1.,
predict_x_start = False,
predict_v = False,
learned_variance = False,
clip_denoised = True,
lowres_cond_img = None,
@@ -2778,12 +2911,13 @@ class Decoder(nn.Module):
inpaint_mask = None,
inpaint_resample_times = 5
):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2825,16 +2959,27 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output)
if predict_x_start:
# predict x0
if predict_v:
x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
elif predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
# maybe clip x0
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
# predict noise
if predict_x_start or predict_v:
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
else:
pred_noise = pred
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if not is_last_timestep else 0.
@@ -2867,7 +3012,7 @@ class Decoder(nn.Module):
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
noise = default(noise, lambda: torch.randn_like(x_start))
# normalize to [-1, 1]
@@ -2912,7 +3057,12 @@ class Decoder(nn.Module):
pred, _ = self.parse_unet_output(learned_variance, unet_output)
target = noise if not predict_x_start else x_start
if predict_v:
target = noise_scheduler.calculate_v(x_start, times, noise)
elif predict_x_start:
target = x_start
else:
target = noise
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -2998,7 +3148,7 @@ class Decoder(nn.Module):
num_unets = self.num_unets
cond_scale = cast_tuple(cond_scale, num_unets)
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
if unet_number < start_at_unet_number:
continue # It's the easiest way to do it
@@ -3034,6 +3184,7 @@ class Decoder(nn.Module):
text_encodings = text_encodings,
cond_scale = unet_cond_scale,
predict_x_start = predict_x_start,
predict_v = predict_v,
learned_variance = learned_variance,
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
@@ -3073,6 +3224,7 @@ class Decoder(nn.Module):
lowres_conditioner = self.lowres_conds[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
predict_v = self.predict_v[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, c, h, w, device, = *image.shape, image.device
@@ -3111,7 +3263,7 @@ class Decoder(nn.Module):
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
if not return_lowres_cond_image:
return losses

View File

@@ -4,11 +4,13 @@ from pydantic import BaseModel, validator, root_validator
from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
from x_clip import CLIP as XCLIP
from open_clip import list_pretrained
from coca_pytorch import CoCa
from dalle2_pytorch.dalle2_pytorch import (
CoCaAdapter,
OpenAIClipAdapter,
OpenClipAdapter,
Unet,
Decoder,
DiffusionPrior,
@@ -117,6 +119,10 @@ class AdapterConfig(BaseModel):
def create(self):
if self.make == "openai":
return OpenAIClipAdapter(self.model)
elif self.make == "open_clip":
pretrained = dict(list_pretrained())
checkpoint = pretrained[self.model]
return OpenClipAdapter(name=self.model, pretrained=checkpoint)
elif self.make == "x-clip":
return XClipAdapter(XCLIP(**self.base_model_kwargs))
elif self.make == "coca":
@@ -241,7 +247,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable[int]] = None
sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True
@@ -307,6 +313,7 @@ class DecoderTrainConfig(BaseModel):
wd: SingularOrIterable[float] = 0.01
warmup_steps: Optional[SingularOrIterable[int]] = None
find_unused_parameters: bool = True
static_graph: bool = True
max_grad_norm: SingularOrIterable[float] = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = 1,
warmup_steps = None,
cosine_decay_max_steps = None,
**kwargs
):
super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmupstep
if exists(self.warmup_scheduler):
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
with self.warmup_scheduler.dampening():
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
with sched_context():
self.scheduler.step()
if self.use_ema:
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
)
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}
scheduler = getattr(self, scheduler_key)
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scheduler_key = f'sched{ind}'
scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
if optimizer is not None:
if exists(optimizer):
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(scheduler):
scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step

View File

@@ -1 +1 @@
__version__ = '1.6.5'
__version__ = '1.11.4'

View File

@@ -26,6 +26,7 @@ setup(
install_requires=[
'accelerate',
'click',
'open-clip-torch>=2.0.0,<3.0.0',
'clip-anytorch>=2.4.0',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',

View File

@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
img_embeddings, *_ = trainer.embed_image(imgs_tensor)
assert clip is not None, "clip is None, but img_embeddings is None"
imgs_tensor.to(device=device)
img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
tokenized_texts = tokenize(txts, truncate=True)
sample_params["text"] = tokenized_texts
assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True).to(device=device)
text_embed, text_encodings = clip.embed_text(tokenized_texts)
sample_params["text_encodings"] = text_encodings
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
samples = trainer.sample(**sample_params)
samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -225,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
metrics["KID_std"] = kid_std.item()
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
lpips.to(device=device)
lpips.update(renorm_real_images, renorm_generated_images)
@@ -265,6 +269,7 @@ def train(
accelerator: Accelerator,
tracker: Tracker,
inference_device,
clip=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
@@ -371,15 +376,19 @@ def train(
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
pass
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -419,7 +428,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def train(
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
pass
assert clip is not None
img_embed, img_encoding = clip.embed_image(img)
forward_params['image_embed'] = img_embed
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
tokenized_texts = tokenize(txt, truncate=True)
assert clip is not None
tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
forward_params['text'] = tokenized_texts
text_embed, text_encodings = clip.embed_text(tokenized_texts)
forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
average_val_loss_tensor[0, unet-1] += loss
@@ -498,7 +511,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
@@ -509,8 +522,8 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
@@ -542,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
seed = config.seed,
)
# If clip is in the model, we need to remove it for compatibility with deepspeed
clip = None
if config.decoder.clip is not None:
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
config.decoder.clip = None
# Create the decoder model and print basic info
decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
has_clip_model = config.decoder.clip is not None
has_clip_model = clip is not None
data_source_string = ""
if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator,
clip=clip,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,