Compare commits


25 Commits

Author SHA1 Message Date
Phil Wang
1cc5d0afa7 upgrade to best downsample 2022-08-25 10:37:02 -07:00
Phil Wang
59fa101c4d fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue 2022-08-23 08:29:01 -07:00
Aidan Dempster
916ece164c Merge pull request #234 from Veldrovive/deepspeed_fp16
Fixed issues with clip and deepspeed fp16
2022-08-20 19:01:43 -04:00
Aidan
cbaadb6931 Fixed issues with clip and deepspeed fp16
Also more general compatibility fixes
2022-08-20 17:58:32 +00:00
Phil Wang
083508ff8e cast attention matrix back to original dtype pre-softmax in attention 2022-08-20 10:56:01 -07:00
Phil Wang
7762edd0ff make it work for @ethancohen123 2022-08-19 11:28:58 -07:00
Phil Wang
de5e628773 cite einops 2022-08-17 08:58:41 -07:00
Phil Wang
1b4046b039 gratitude 2022-08-17 08:57:33 -07:00
Phil Wang
27f19ba7fa make sure diffusion prior trainer can operate with no warmup 2022-08-15 14:27:40 -07:00
Phil Wang
8f38339c2b give diffusion prior trainer cosine annealing lr too 2022-08-15 07:38:01 -07:00
Phil Wang
6b9b4b9e5e add cosine annealing lr schedule 2022-08-15 07:29:56 -07:00
Phil Wang
44e09d5a4d add weight standardization behind feature flag, which may potentially work well with group norm 2022-08-14 11:34:45 -07:00
Phil Wang
34806663e3 make it so diffusion prior p_sample_loop returns unnormalized image embeddings 2022-08-13 10:03:40 -07:00
Phil Wang
dc816b1b6e dry up some code around handling unet outputs with learned variance 2022-08-12 15:25:03 -07:00
Phil Wang
05192ffac4 fix self conditioning shape in diffusion prior 2022-08-12 12:30:03 -07:00
Phil Wang
9440411954 make self conditioning technique work with diffusion prior 2022-08-12 12:20:51 -07:00
Phil Wang
981d407792 comment 2022-08-12 11:41:23 -07:00
Phil Wang
7c5477b26d bet on the new self-conditioning technique out of geoffrey hinton's group 2022-08-12 11:36:08 -07:00
Phil Wang
be3bb868bf add gradient checkpointing for all resnet blocks 2022-08-02 19:21:44 -07:00
Phil Wang
451de34871 enforce clip anytorch version 2022-07-30 10:07:55 -07:00
Phil Wang
f22e8c8741 make open clip available for use with dalle2 pytorch 2022-07-30 09:02:31 -07:00
Phil Wang
87432e93ad quick fix for linear attention 2022-07-29 13:17:12 -07:00
Phil Wang
d167378401 add cosine sim for self attention as well, as a setting 2022-07-29 12:48:20 -07:00
Phil Wang
2d67d5821e change up epsilon in layernorm in the case of using fp16, thanks to @Veldrovive for figuring out this stabilizes training 2022-07-29 12:41:02 -07:00
Phil Wang
748c7fe7af allow for cosine sim cross attention, modify linear attention in attempt to resolve issue on fp16 2022-07-29 11:12:18 -07:00
7 changed files with 527 additions and 147 deletions

View File

@@ -49,6 +49,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice - <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship - <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library - <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏 ... and many others. Thank you! 🙏
@@ -627,6 +628,18 @@ images = dalle2(
# save your image (in this example, of size 256x256) # save your image (in this example, of size 256x256)
``` ```
Alternatively, you can use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
```bash
$ pip install open-clip-torch
```
```python
from dalle2_pytorch import OpenClipAdapter
clip = OpenClipAdapter()
```
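The defaults shown in the `OpenClipAdapter` constructor further down in this compare are a ViT-B/32 model with LAION-400M weights; as a hedged sketch, any other open_clip model/checkpoint pair should be selectable the same way, and the resulting adapter is then passed around exactly like the `OpenAIClipAdapter`:
```python
from dalle2_pytorch import OpenClipAdapter

# explicit form of the adapter's own defaults; swap in another open_clip checkpoint if desired
clip = OpenClipAdapter(
    name = 'ViT-B/32',
    pretrained = 'laion400m_e32'
)
```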
Now you'll just have to worry about training the Prior and the Decoder! Now you'll just have to worry about training the Prior and the Decoder!
## Inpainting ## Inpainting
@@ -1241,4 +1254,45 @@ For detailed information on training the diffusion prior, please refer to the [d
} }
``` ```
```bibtex
@misc{chen2022analog,
title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
year = {2022},
eprint = {2208.04202},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@article{Qiao2019WeightS,
title = {Weight Standardization},
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
journal = {ArXiv},
year = {2019},
volume = {abs/1903.10520}
}
```
```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
author = {Raja Sunkara and Tie Luo},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.03641}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -8,6 +8,7 @@ from pathlib import Path
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import nn, einsum from torch import nn, einsum
import torchvision.transforms as T import torchvision.transforms as T
@@ -37,6 +38,8 @@ from coca_pytorch import CoCa
NAT = 1. / math.log(2.) NAT = 1. / math.log(2.)
UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
# helper functions # helper functions
def exists(val): def exists(val):
@@ -108,6 +111,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t return t
return (*t, *((fillvalue,) * remain_length)) return (*t, *((fillvalue,) * remain_length))
# checkpointing helper function
def make_checkpointable(fn, **kwargs):
if isinstance(fn, nn.ModuleList):
return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
condition = kwargs.pop('condition', None)
if exists(condition) and not condition(fn):
return fn
@wraps(fn)
def inner(*args):
input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
if not input_needs_grad:
return fn(*args)
return checkpoint(fn, *args)
return inner
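`make_checkpointable` wraps a module (or every module inside an `nn.ModuleList`) so that its activations are discarded during the forward pass and recomputed during backward, and it only does so when at least one tensor argument requires gradients, since `torch.utils.checkpoint` has nothing to recompute through otherwise. A minimal sketch of the underlying trade-off, using a hypothetical block:
```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# hypothetical block; intermediate activations inside it are not stored
block = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

x = torch.randn(8, 512, requires_grad = True)

out = checkpoint(block, x)   # forward pass keeps only the input and output
out.sum().backward()         # the forward through `block` is re-run here to get gradients
```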
# for controlling freezing of CLIP # for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad): def set_module_requires_grad_(module, requires_grad):
@@ -225,9 +250,15 @@ class XClipAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len] text = text[..., :self.max_text_len]
text_mask = text != 0 text_mask = text != 0
encoder_output = self.clip.text_transformer(text) encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
encoder_output_is_cls = encoder_output.ndim == 3
text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
text_embed = self.clip.to_text_latent(text_cls) text_embed = self.clip.to_text_latent(text_cls)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
if exists(text_encodings):
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings) return EmbeddedText(l2norm(text_embed), text_encodings)
@torch.no_grad() @torch.no_grad()
@@ -339,6 +370,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
image_embed = self.clip.encode_image(image) image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None) return EmbeddedImage(l2norm(image_embed.float()), None)
class OpenClipAdapter(BaseClipAdapter):
def __init__(
self,
name = 'ViT-B/32',
pretrained = 'laion400m_e32'
):
import open_clip
clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
super().__init__(clip)
self.eos_id = 49407
text_attention_final = self.find_layer('ln_final')
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.image_size
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
is_eos_id = (text == self.eos_id)
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
assert not self.cleared
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
# classifier free guidance functions # classifier free guidance functions
def prob_mask_like(shape, prob, device): def prob_mask_like(shape, prob, device):
@@ -547,34 +647,40 @@ class NoiseScheduler(nn.Module):
# diffusion prior # diffusion prior
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False): def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable self.stable = stable
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
def forward(self, x): def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable: if self.stable:
x = x / x.amax(dim = -1, keepdim = True).detach() x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True) var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g return (x - mean) * (var + eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module): class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5, stable = False): def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.fp16_eps = fp16_eps
self.stable = stable self.stable = stable
self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x): def forward(self, x):
eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
if self.stable: if self.stable:
x = x / x.amax(dim = 1, keepdim = True).detach() x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True) var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g return (x - mean) * (var + eps).rsqrt() * self.g
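The new `fp16_eps` makes both layernorms switch to a larger epsilon whenever the input is not float32, which the commit above credits to @Veldrovive as stabilizing fp16 training. As a hedged illustration of why the value matters (not necessarily the whole story): when the variance collapses toward zero, the epsilon alone bounds the inverse standard deviation, and half precision leaves little headroom for large intermediate values.
```python
import torch

# with a near-zero variance, the epsilon alone bounds the inverse standard deviation:
# eps = 1e-5 lets it reach ~316, while the fp16 epsilon of 1e-3 caps it at ~31.6
var = torch.tensor(0.)
print((var + 1e-5).rsqrt())   # tensor(316.2278)
print((var + 1e-3).rsqrt())   # tensor(31.6228)
```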
class Residual(nn.Module): class Residual(nn.Module):
def __init__(self, fn): def __init__(self, fn):
@@ -695,11 +801,12 @@ class Attention(nn.Module):
dropout = 0., dropout = 0.,
causal = False, causal = False,
rotary_emb = None, rotary_emb = None,
pb_relax_alpha = 128 cosine_sim = True,
cosine_sim_scale = 16
): ):
super().__init__() super().__init__()
self.pb_relax_alpha = pb_relax_alpha self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.cosine_sim = cosine_sim
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@@ -739,6 +846,13 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
# whether to use cosine sim
if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
# calculate query / key similarities # calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k) sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -764,10 +878,9 @@ class Attention(nn.Module):
# attention # attention
sim = sim - sim.amax(dim = -1, keepdim = True).detach() attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim * self.pb_relax_alpha attn = attn.type(sim.dtype)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn) attn = self.dropout(attn)
# aggregate values # aggregate values
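The hunks above replace the PB-relax trick with cosine-similarity attention: queries and keys are l2-normalized so the logits are bounded, a fixed `cosine_sim_scale` replaces `dim_head ** -0.5`, and the softmax runs in fp32 before casting back, for fp16 safety. A hedged standalone sketch of the pattern (full multi-head keys here for simplicity; the library's version shares keys and values across heads, hence the `b j d` in its einsum):
```python
import math
import torch
import torch.nn.functional as F
from torch import einsum

def l2norm(t):
    return F.normalize(t, dim = -1)

def cosine_sim_attention(q, k, v, scale = 16):
    # unit-norm queries and keys bound the similarity logits,
    # so a fixed scale stands in for dim_head ** -0.5 and PB-relax
    q, k = map(l2norm, (q, k))
    q, k = map(lambda t: t * math.sqrt(scale), (q, k))

    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    # softmax in fp32, then cast back to the working dtype
    attn = sim.softmax(dim = -1, dtype = torch.float32).type(sim.dtype)
    return einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 32, 64)   # (batch, heads, seq, dim_head)
out = cosine_sim_attention(q, k, v)     # (1, 8, 32, 64)
```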
@@ -834,9 +947,12 @@ class DiffusionPriorNetwork(nn.Module):
num_image_embeds = 1, num_image_embeds = 1,
num_text_embeds = 1, num_text_embeds = 1,
max_text_len = 256, max_text_len = 256,
self_cond = False,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.dim = dim
self.num_time_embeds = num_time_embeds self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds self.num_text_embeds = num_text_embeds
@@ -862,7 +978,14 @@ class DiffusionPriorNetwork(nn.Module):
# dalle1 learned padding strategy # dalle1 learned padding strategy
self.max_text_len = max_text_len self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
self.null_image_embed = nn.Parameter(torch.randn(1, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique
self.self_cond = self_cond
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
@@ -875,7 +998,7 @@ class DiffusionPriorNetwork(nn.Module):
if cond_scale == 1: if cond_scale == 1:
return logits return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
return null_logits + (logits - null_logits) * cond_scale return null_logits + (logits - null_logits) * cond_scale
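The reworked `forward_with_cond_scale` now drops the text and image conditions independently for the null pass, but the guidance rule itself is the usual classifier-free guidance interpolation; writing the network as $f$ and `cond_scale` as $s$:
```latex
\hat{f}(x_t, c) = f(x_t, \varnothing) + s \cdot \bigl( f(x_t, c) - f(x_t, \varnothing) \bigr)
```
With $s = 1$ this reduces to the purely conditional prediction, which is why the early return above skips the unconditional forward pass.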
def forward( def forward(
@@ -885,18 +1008,34 @@ class DiffusionPriorNetwork(nn.Module):
*, *,
text_embed, text_embed,
text_encodings = None, text_encodings = None,
cond_drop_prob = 0. self_cond = None,
text_cond_drop_prob = 0.,
image_cond_drop_prob = 0.
): ):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# setup self conditioning
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
self_cond = rearrange(self_cond, 'b d -> b 1 d')
# in section 2.2, last paragraph # in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed = self.to_text_embeds(text_embed) text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed) image_embed = self.to_image_embeds(image_embed)
# classifier free guidance masks
text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
# make text encodings optional # make text encodings optional
# although the paper seems to suggest it is present <-- # although the paper seems to suggest it is present <--
@@ -917,36 +1056,46 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.) text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
mask = F.pad(mask, (0, remainder), value = False) mask = F.pad(mask, (0, remainder), value = False)
null_text_embeds = self.null_text_embed.to(text_encodings.dtype) # mask out text encodings with null encodings
null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)
text_encodings = torch.where( text_encodings = torch.where(
rearrange(mask, 'b n -> b n 1').clone(), rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
text_encodings, text_encodings,
null_text_encodings
)
# mask out text embeddings with null text embeddings
null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
text_embeds = torch.where(
text_keep_mask,
text_embed,
null_text_embeds null_text_embeds
) )
# classifier free guidance # mask out image embeddings with null image embeddings
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device) null_image_embed = self.null_image_embed.to(image_embed.dtype)
keep_mask = rearrange(keep_mask, 'b -> b 1')
mask &= keep_mask image_embed = torch.where(
image_keep_mask,
# whether text embedding is masked or not depends on the classifier free guidance conditional masking image_embed,
null_image_embed
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds) )
mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps) time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
tokens = torch.cat(( tokens = torch.cat((
text_encodings, text_encodings,
text_embed, text_embed,
@@ -977,6 +1126,8 @@ class DiffusionPrior(nn.Module):
timesteps = 1000, timesteps = 1000,
sample_timesteps = None, sample_timesteps = None,
cond_drop_prob = 0., cond_drop_prob = 0.,
text_cond_drop_prob = None,
image_cond_drop_prob = None,
loss_type = "l2", loss_type = "l2",
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
@@ -1017,8 +1168,10 @@ class DiffusionPrior(nn.Module):
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent) self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels) self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
self.can_classifier_guidance = cond_drop_prob > 0. self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
self.condition_on_text_encodings = condition_on_text_encodings self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -1048,45 +1201,50 @@ class DiffusionPrior(nn.Module):
def l2norm_clamp_embed(self, image_embed): def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.): def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond) pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start: if self.predict_x_start:
x_recon = pred x_start = pred
else: else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not self.predict_x_start: if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1., 1.) x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm: if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale x_start = l2norm(x_start) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.): def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale) model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad() @torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.): def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device) image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm: if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps): for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((batch,), i, device = device, dtype = torch.long) times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start: if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed) image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1104,6 +1262,8 @@ class DiffusionPrior(nn.Module):
image_embed = torch.randn(shape, device = device) image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm: if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1113,7 +1273,9 @@ class DiffusionPrior(nn.Module):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long) time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond) self_cond = x_start if self.net.self_cond else None
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start: if self.predict_x_start:
x_start = pred x_start = pred
@@ -1148,19 +1310,29 @@ class DiffusionPrior(nn.Module):
is_ddim = timesteps < self.noise_scheduler.num_timesteps is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim: if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs) normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
else:
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps) image_embed = normalized_image_embed / self.image_embed_scale
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None): def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed)) noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise) image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net( pred = self.net(
image_embed_noisy, image_embed_noisy,
times, times,
cond_drop_prob = self.cond_drop_prob, self_cond = self_cond,
text_cond_drop_prob = self.text_cond_drop_prob,
image_cond_drop_prob = self.image_cond_drop_prob,
**text_cond **text_cond
) )
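The `p_losses` change above is the training half of the self-conditioning recipe from Analog Bits (cited in the README diff): on a random half of the steps, the network first produces a detached estimate with no self-conditioning, and that estimate is fed back in as `self_cond` for the gradient-carrying pass; at sampling time the previous step's `x_start` plays the same role. A minimal runnable sketch with a hypothetical stand-in network:
```python
import random
import torch
from torch import nn

class TinyDenoiser(nn.Module):
    # hypothetical stand-in for the diffusion prior network; timestep is ignored in this toy
    self_cond = True

    def __init__(self, dim = 512):
        super().__init__()
        self.net = nn.Linear(dim * 2, dim)

    def forward(self, x, times, self_cond = None):
        self_cond = torch.zeros_like(x) if self_cond is None else self_cond
        return self.net(torch.cat((x, self_cond), dim = -1))

def denoise_with_self_conditioning(net, x_noisy, times):
    # 50% of training steps: take a detached first prediction without self-conditioning,
    # then feed it back in; gradients flow only through the second forward pass
    self_cond = None
    if net.self_cond and random.random() < 0.5:
        with torch.no_grad():
            self_cond = net(x_noisy, times).detach()

    return net(x_noisy, times, self_cond = self_cond)

net = TinyDenoiser()
pred = denoise_with_self_conditioning(net, torch.randn(4, 512), torch.randint(0, 1000, (4,)))
```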
@@ -1213,8 +1385,6 @@ class DiffusionPrior(nn.Module):
# retrieve original unscaled image embed # retrieve original unscaled image embed
image_embeds /= self.image_embed_scale
text_embeds = text_cond['text_embed'] text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -1309,9 +1479,34 @@ class PixelShuffleUpsample(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
def Downsample(dim, *, dim_out = None): def Downsample(dim, dim_out = None):
# https://arxiv.org/abs/2208.03641 shows this is the optimal way to downsample
# named SP-conv in the paper, but basically a pixel unshuffle
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1) return nn.Sequential(
Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
nn.Conv2d(dim * 4, dim_out, 1)
)
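The new `Downsample` follows the SP-conv of arXiv:2208.03641: each 2×2 spatial patch is folded into channels and then mixed by a 1×1 convolution, so nothing is thrown away the way strided convolution or pooling would. The `Rearrange` is the same permutation as `nn.PixelUnshuffle(2)`; a quick hedged check:
```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, dim_out = 64, 128
x = torch.randn(1, dim, 32, 32)

sp_conv = nn.Sequential(
    Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
    nn.Conv2d(dim * 4, dim_out, 1)
)

assert torch.equal(sp_conv[0](x), nn.PixelUnshuffle(2)(x))   # identical channel ordering

print(sp_conv(x).shape)   # torch.Size([1, 128, 16, 16])
```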
class WeightStandardizedConv2d(nn.Conv2d):
"""
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
weight = self.weight
flattened_weights = rearrange(weight, 'o ... -> o (...)')
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = torch.var(flattened_weights, dim = -1, unbiased = False)
var = rearrange(var, 'o -> o 1 1 1')
weight = (weight - mean) * (var + eps).rsqrt()
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
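`WeightStandardizedConv2d` implements weight standardization (arXiv:1903.10520): each output filter is renormalized to zero mean and unit variance on every forward pass, which the commit notes may work well with the GroupNorm already used in `Block`. A small hedged check of what the reparameterization does to a kernel:
```python
import torch
from einops import rearrange, reduce

weight = torch.randn(16, 3, 3, 3)   # (out_channels, in_channels, kh, kw)

mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = rearrange(torch.var(rearrange(weight, 'o ... -> o (...)'), dim = -1, unbiased = False), 'o -> o 1 1 1')
standardized = (weight - mean) * (var + 1e-5).rsqrt()

flat = rearrange(standardized, 'o ... -> o (...)')
print(flat.mean(dim = -1))                     # ≈ 0 for every filter
print(flat.var(dim = -1, unbiased = False))    # ≈ 1 for every filter
```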
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
@@ -1331,10 +1526,13 @@ class Block(nn.Module):
self, self,
dim, dim,
dim_out, dim_out,
groups = 8 groups = 8,
weight_standardization = False
): ):
super().__init__() super().__init__()
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1) conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
self.project = conv_klass(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out) self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU() self.act = nn.SiLU()
@@ -1357,7 +1555,9 @@ class ResnetBlock(nn.Module):
*, *,
cond_dim = None, cond_dim = None,
time_cond_dim = None, time_cond_dim = None,
groups = 8 groups = 8,
weight_standardization = False,
cosine_sim_cross_attn = False
): ):
super().__init__() super().__init__()
@@ -1377,12 +1577,13 @@ class ResnetBlock(nn.Module):
'b (h w) c', 'b (h w) c',
CrossAttention( CrossAttention(
dim = dim_out, dim = dim_out,
context_dim = cond_dim context_dim = cond_dim,
cosine_sim = cosine_sim_cross_attn
) )
) )
self.block1 = Block(dim, dim_out, groups = groups) self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
self.block2 = Block(dim_out, dim_out, groups = groups) self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None, cond = None): def forward(self, x, time_emb = None, cond = None):
@@ -1412,11 +1613,12 @@ class CrossAttention(nn.Module):
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
norm_context = False, norm_context = False,
pb_relax_alpha = 32 ** 2 cosine_sim = False,
cosine_sim_scale = 16
): ):
super().__init__() super().__init__()
self.pb_relax_alpha = pb_relax_alpha self.cosine_sim = cosine_sim
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1) self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
self.heads = heads self.heads = heads
inner_dim = dim_head * heads inner_dim = dim_head * heads
@@ -1452,7 +1654,10 @@ class CrossAttention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
q = q * self.scale if self.cosine_sim:
q, k = map(l2norm, (q, k))
q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
sim = einsum('b h i d, b h j d -> b h i j', q, k) sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max
@@ -1462,10 +1667,8 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j') mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value) sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True).detach() attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim * self.pb_relax_alpha attn = attn.type(sim.dtype)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1476,7 +1679,8 @@ class LinearAttention(nn.Module):
self, self,
dim, dim,
dim_head = 32, dim_head = 32,
heads = 8 heads = 8,
**kwargs
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
@@ -1494,6 +1698,7 @@ class LinearAttention(nn.Module):
def forward(self, fmap): def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:] h, x, y = self.heads, *fmap.shape[-2:]
seq_len = x * y
fmap = self.norm(fmap) fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1) q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1503,7 +1708,9 @@ class LinearAttention(nn.Module):
k = k.softmax(dim = -2) k = k.softmax(dim = -2)
q = q * self.scale q = q * self.scale
v = v / (x * y) v = l2norm(v)
k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
context = einsum('b n d, b n e -> b d e', k, v) context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context) out = einsum('b n d, b d e -> b n e', q, context)
@@ -1590,7 +1797,10 @@ class Unet(nn.Module):
attn_heads = 16, attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
sparse_attn = False, sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False, cond_on_text_encodings = False,
max_text_len = 256, max_text_len = 256,
@@ -1599,6 +1809,7 @@ class Unet(nn.Module):
init_dim = None, init_dim = None,
init_conv_kernel_size = 7, init_conv_kernel_size = 7,
resnet_groups = 8, resnet_groups = 8,
resnet_weight_standardization = False,
num_resnet_blocks = 2, num_resnet_blocks = 2,
init_cross_embed = True, init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15), init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1609,6 +1820,7 @@ class Unet(nn.Module):
pixel_shuffle_upsample = True, pixel_shuffle_upsample = True,
final_conv_kernel_size = 1, final_conv_kernel_size = 1,
combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
checkpoint_during_training = False,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
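Several of the additions in this compare surface as new `Unet` keyword arguments. A hedged sketch of enabling them together; the core hyperparameters mirror the README's decoder example and are placeholders:
```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    self_cond = True,                      # self-conditioning, arXiv:2208.04202
    cosine_sim_self_attn = True,           # cosine-sim self attention
    cosine_sim_cross_attn = True,          # cosine-sim cross attention
    resnet_weight_standardization = True,  # weight-standardized convs in resnet blocks
    checkpoint_during_training = True      # gradient checkpointing while training
)
```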
@@ -1622,12 +1834,21 @@ class Unet(nn.Module):
self.lowres_cond = lowres_cond self.lowres_cond = lowres_cond
# whether to do self conditioning
self.self_cond = self_cond
# determine dimensions # determine dimensions
self.channels = channels self.channels = channels
self.channels_out = default(channels_out, channels) self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis # initial number of channels depends on
# (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
# (2) self conditioning (bit diffusion paper)
init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
init_dim = default(init_dim, dim) init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
@@ -1711,7 +1932,7 @@ class Unet(nn.Module):
# attention related params # attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head) attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
self_attn = cast_tuple(self_attn, num_stages) self_attn = cast_tuple(self_attn, num_stages)
@@ -1734,9 +1955,13 @@ class Unet(nn.Module):
upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
# give memory efficient unet an initial resnet block # give memory efficient unet an initial resnet block
self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
# layers # layers
@@ -1763,17 +1988,17 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None, downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups), resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention, attention,
downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1) downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
])) ]))
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
self.mid_attn = create_self_attn(mid_dim) self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))): for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
is_last = ind >= (len(in_out) - 1) is_last = ind >= (len(in_out) - 1)
@@ -1790,8 +2015,8 @@ class Unet(nn.Module):
upsample_combiner_dims.append(dim_out) upsample_combiner_dims.append(dim_out)
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups), resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]), nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
attention, attention,
upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity() upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
])) ]))
@@ -1807,7 +2032,7 @@ class Unet(nn.Module):
# a final resnet block # a final resnet block
self.final_resnet_block = ResnetBlock(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
out_dim_in = dim + (channels if lowres_cond else 0) out_dim_in = dim + (channels if lowres_cond else 0)
@@ -1815,6 +2040,10 @@ class Unet(nn.Module):
zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
# whether to checkpoint during training
self.checkpoint_during_training = checkpoint_during_training
# if the current settings for the unet are not correct # if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings # for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters( def cast_model_parameters(
@@ -1872,7 +2101,9 @@ class Unet(nn.Module):
image_cond_drop_prob = 0., image_cond_drop_prob = 0.,
text_cond_drop_prob = 0., text_cond_drop_prob = 0.,
blur_sigma = None, blur_sigma = None,
blur_kernel_size = None blur_kernel_size = None,
disable_checkpoint = False,
self_cond = None
): ):
batch_size, device = x.shape[0], x.device batch_size, device = x.shape[0], x.device
@@ -1880,6 +2111,14 @@ class Unet(nn.Module):
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
# concat self conditioning, if needed
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x, self_cond), dim = 1)
# concat low resolution conditioning
if exists(lowres_cond_img): if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1) x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -1994,17 +2233,29 @@ class Unet(nn.Module):
c = self.norm_cond(c) c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c) mid_c = self.norm_mid_cond(mid_c)
# gradient checkpointing
can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
# make checkpointable modules
init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
# initial resnet block # initial resnet block
if exists(self.init_resnet_block): if exists(init_resnet_block):
x = self.init_resnet_block(x, t) x = init_resnet_block(x, t)
# go through the layers of the unet, down and up # go through the layers of the unet, down and up
down_hiddens = [] down_hiddens = []
up_hiddens = [] up_hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs: for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
if exists(pre_downsample): if exists(pre_downsample):
x = pre_downsample(x) x = pre_downsample(x)
@@ -2020,16 +2271,16 @@ class Unet(nn.Module):
if exists(post_downsample): if exists(post_downsample):
x = post_downsample(x) x = post_downsample(x)
x = self.mid_block1(x, t, mid_c) x = mid_block1(x, t, mid_c)
if exists(self.mid_attn): if exists(mid_attn):
x = self.mid_attn(x) x = mid_attn(x)
x = self.mid_block2(x, t, mid_c) x = mid_block2(x, t, mid_c)
connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1) connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
for init_block, resnet_blocks, attn, upsample in self.ups: for init_block, resnet_blocks, attn, upsample in ups:
x = connect_skip(x) x = connect_skip(x)
x = init_block(x, t, c) x = init_block(x, t, c)
@@ -2046,7 +2297,7 @@ class Unet(nn.Module):
x = torch.cat((x, r), dim = 1) x = torch.cat((x, r), dim = 1)
x = self.final_resnet_block(x, t) x = final_resnet_block(x, t)
if exists(lowres_cond_img): if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1) x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -2399,6 +2650,14 @@ class Decoder(nn.Module):
index = unet_number - 1 index = unet_number - 1
return self.unets[index] return self.unets[index]
def parse_unet_output(self, learned_variance, output):
var_interp_frac_unnormalized = None
if learned_variance:
output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
return UnetOutput(output, var_interp_frac_unnormalized)
@contextmanager @contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None): def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet) assert exists(unet_number) ^ exists(unet)
@@ -2437,23 +2696,22 @@ class Decoder(nn.Module):
x = x.clamp(-s, s) / s x = x.clamp(-s, s) / s
return x return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)) model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
if learned_variance: pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start: if predict_x_start:
x_recon = pred x_start = pred
else: else:
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised: if clip_denoised:
x_recon = self.dynamic_threshold(x_recon) x_start = self.dynamic_threshold(x_start)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
if learned_variance: if learned_variance:
# if learned variance, posterior variance and posterior log variance are predicted by the network # if learned variance, posterior variance and posterior log variance are predicted by the network
@@ -2469,16 +2727,17 @@ class Decoder(nn.Module):
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp() posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
noise = torch.randn_like(x) noise = torch.randn_like(x)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad() @torch.no_grad()
def p_sample_loop_ddpm( def p_sample_loop_ddpm(
@@ -2504,6 +2763,8 @@ class Decoder(nn.Module):
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
x_start = None # for self-conditioning
is_inpaint = exists(inpaint_image) is_inpaint = exists(inpaint_image)
resample_times = inpaint_resample_times if is_inpaint else 1 resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2531,13 +2792,16 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
img = self.p_sample( self_cond = x_start if unet.self_cond else None
img, x_start = self.p_sample(
unet, unet,
img, img,
times, times,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = cond_scale, cond_scale = cond_scale,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img, lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level, lowres_noise_level = lowres_noise_level,
predict_x_start = predict_x_start, predict_x_start = predict_x_start,
@@ -2596,6 +2860,8 @@ class Decoder(nn.Module):
img = torch.randn(shape, device = device)
+ x_start = None # for self-conditioning
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
@@ -2616,10 +2882,11 @@ class Decoder(nn.Module):
noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
- pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
- if learned_variance:
- pred, _ = pred.chunk(2, dim = 1)
+ self_cond = x_start if unet.self_cond else None
+ unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+ pred, _ = self.parse_unet_output(learned_variance, unet_output)
if predict_x_start:
x_start = pred
@@ -2676,21 +2943,37 @@ class Decoder(nn.Module):
x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
- model_output = unet(
- x_noisy,
- times,
+ # unet kwargs
+ unet_kwargs = dict(
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
lowres_noise_level = lowres_noise_level,
+ )
+ # self conditioning
+ self_cond = None
+ if unet.self_cond and random.random() < 0.5:
+ with torch.no_grad():
+ unet_output = unet(x_noisy, times, **unet_kwargs)
+ self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
+ self_cond = self_cond.detach()
+ # forward to get model prediction
+ unet_output = unet(
+ x_noisy,
+ times,
+ **unet_kwargs,
+ self_cond = self_cond,
image_cond_drop_prob = self.image_cond_drop_prob,
text_cond_drop_prob = self.text_cond_drop_prob,
)
- if learned_variance:
- pred, _ = model_output.chunk(2, dim = 1)
- else:
- pred = model_output
+ pred, _ = self.parse_unet_output(learned_variance, unet_output)
target = noise if not predict_x_start else x_start
@@ -2713,7 +2996,7 @@ class Decoder(nn.Module):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
- model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+ model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
# kl loss with detached model predicted mean, for stability reasons as in paper
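The two hunks above wire self-conditioning into the decoder: at sampling time the previous step's x_start estimate is passed back into the unet, and at training time the unet is first run without gradients on roughly half the batches so its own detached prediction can be fed back in as self_cond. A minimal, self-contained sketch of that training pattern, with a placeholder denoiser callable and q_sample function standing in for the library's unet and noise scheduler:

import random
import torch
import torch.nn.functional as F

def self_conditioned_loss(denoiser, q_sample, x_start, times):
    # corrupt the clean image exactly as in ordinary diffusion training
    noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start, times, noise)

    # roughly half the time, produce a throwaway prediction with no self-conditioning,
    # detach it, and feed it back in as an extra input on the real forward pass
    self_cond = None
    if random.random() < 0.5:
        with torch.no_grad():
            self_cond = denoiser(x_noisy, times, self_cond = None).detach()

    pred = denoiser(x_noisy, times, self_cond = self_cond)
    return F.mse_loss(pred, x_start)  # assumes a predict-x0 objective, as a simplification

At sampling time the same self_cond slot is simply filled with the x_start predicted at the previous timestep, which is what the x_start = None initialization and the self_cond = x_start if unet.self_cond else None line above implement.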

View File

@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
- sample_timesteps: Optional[SingularOrIterable[int]] = None
+ sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple[str] = None # None means all cosine
learned_variance: SingularOrIterable[bool] = True

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
- from torch.optim.lr_scheduler import LambdaLR
+ from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
group_wd_params = True,
- warmup_steps = 1,
+ warmup_steps = None,
+ cosine_decay_max_steps = None,
**kwargs
):
super().__init__()
@@ -234,7 +235,10 @@ class DiffusionPriorTrainer(nn.Module):
**kwargs
)
- self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+ if exists(cosine_decay_max_steps):
+ self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+ else:
+ self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
+ scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
+ self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmupstep
if exists(self.warmup_scheduler):
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
- with self.warmup_scheduler.dampening():
+ sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+ with sched_context():
self.scheduler.step()
if self.use_ema:
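Before the DecoderTrainer hunks below, here is the scheduler wiring these changes introduce, in isolation: an optional CosineAnnealingLR (falling back to a constant LambdaLR), dampened by a linear warmup when one is configured, with a nullcontext otherwise. A rough standalone sketch under assumed hyperparameters; the model, optimizer and step counts are made up for illustration, and the warmup import assumes the pytorch-warmup package, which matches the warmup.LinearWarmup calls in the diff:

from contextlib import nullcontext

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import pytorch_warmup as warmup

model = torch.nn.Linear(16, 16)                  # stand-in for the diffusion prior / unet
optimizer = AdamW(model.parameters(), lr = 3e-4)

cosine_decay_max_steps = 10_000                  # None would keep the flat LambdaLR schedule
warmup_steps = 500                               # None would disable warmup entirely

if cosine_decay_max_steps is not None:
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)

warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if warmup_steps is not None else None

for step in range(cosine_decay_max_steps):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()   # dummy objective
    loss.backward()
    optimizer.step()

    # the warmup context scales down the scheduler's LR while warmup is active;
    # with no warmup scheduler we fall back to a do-nothing context manager
    sched_context = warmup_scheduler.dampening if warmup_scheduler is not None else nullcontext
    with sched_context():
        scheduler.step()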
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
+ cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
- lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+ lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []
- for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+ for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
)
optimizers.append(optimizer)
- scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+ if exists(unet_cosine_decay_max_steps):
+ scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+ else:
+ scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
+ scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
- state_dict = optimizer.state_dict() if optimizer is not None else None
- save_obj = {**save_obj, optimizer_key: state_dict}
+ scheduler = getattr(self, scheduler_key)
+ optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+ scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+ save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
+ scheduler_key = f'sched{ind}'
+ scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
- if optimizer is not None:
+ if exists(optimizer):
optimizer.load_state_dict(loaded_obj[optimizer_key])
+ if exists(scheduler):
+ scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
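The save/load hunks above now checkpoint one scheduler per unet next to its optimizer, under optim{ind} / sched{ind} keys, storing None for the nn.Identity placeholder unets. A hedged sketch of that keying pattern on a plain dict (not the trainer's actual save path or file format; the unets, optimizers and T_max value here are illustrative):

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# stand-ins: one trainable unet plus an nn.Identity placeholder with no optimizer/scheduler
unets      = [torch.nn.Linear(8, 8), torch.nn.Identity()]
optimizers = [AdamW(unets[0].parameters(), lr = 1e-4), None]
schedulers = [CosineAnnealingLR(optimizers[0], T_max = 1000), None]

def save_state(path):
    save_obj = {}
    for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
        save_obj[f'optim{ind}'] = optimizer.state_dict() if optimizer is not None else None
        save_obj[f'sched{ind}'] = scheduler.state_dict() if scheduler is not None else None
    torch.save(save_obj, path)

def load_state(path):
    loaded_obj = torch.load(path)
    for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
        if optimizer is not None:
            optimizer.load_state_dict(loaded_obj[f'optim{ind}'])
        if scheduler is not None:
            scheduler.load_state_dict(loaded_obj[f'sched{ind}'])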

View File

@@ -1 +1 @@
- __version__ = '1.2.2'
+ __version__ = '1.10.0'

View File

@@ -26,7 +26,7 @@ setup(
install_requires=[
'accelerate',
'click',
- 'clip-anytorch',
+ 'clip-anytorch>=2.4.0',
'coca-pytorch>=0.0.5',
'ema-pytorch>=0.0.7',
'einops>=0.4',

View File

@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
- def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
if img_embeddings[0] is None:
# Generate image embeddings from clip
imgs_tensor = torch.stack(real_images)
- img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+ assert clip is not None, "clip is None, but img_embeddings is None"
+ imgs_tensor.to(device=device)
+ img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
sample_params["image_embed"] = img_embeddings
else:
# Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
if condition_on_text_encodings:
if text_embeddings[0] is None:
# Generate text embeddings from text
+ assert clip is not None, "clip is None, but text_embeddings is None"
tokenized_texts = tokenize(txts, truncate=True)
- sample_params["text"] = tokenized_texts
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ sample_params["text_encodings"] = text_encodings
else:
# Then we are using precomputed text embeddings
text_embeddings = torch.stack(text_embeddings)
@@ -166,7 +170,7 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
sample_params["image"] = torch.stack(real_images)
if device is not None:
sample_params["_device"] = device
- samples = trainer.sample(**sample_params)
+ samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
@@ -174,15 +178,15 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
- def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+ def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
"""
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
- real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+ real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
- def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
if len(examples) == 0:
print("No data to evaluate. Check that your dataloader has shards.")
return metrics
- real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+ real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -265,6 +269,7 @@ def train(
accelerator: Accelerator,
tracker: Tracker,
inference_device,
+ clip=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
@@ -371,15 +376,19 @@ def train(
forward_params['image_embed'] = img_emb
else:
# Forward pass automatically generates embedding
- pass
+ assert clip is not None
+ img_embed, img_encoding = clip.embed_image(img)
+ forward_params['image_embed'] = img_embed
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb
else:
# Then we need to pass the text instead
- tokenized_texts = tokenize(txt, truncate=True)
+ assert clip is not None
+ tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
- forward_params['text'] = tokenized_texts
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
trainer.update(unet_number=unet)
unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -419,7 +428,7 @@ def train(
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
- train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def train(
forward_params['image_embed'] = img_emb.float()
else:
# Forward pass automatically generates embedding
- pass
+ assert clip is not None
+ img_embed, img_encoding = clip.embed_image(img)
+ forward_params['image_embed'] = img_embed
if condition_on_text_encodings:
if has_text_embedding:
forward_params['text_encodings'] = text_emb.float()
else:
# Then we need to pass the text instead
+ assert clip is not None
tokenized_texts = tokenize(txt, truncate=True)
assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
- forward_params['text'] = tokenized_texts
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ forward_params['text_encodings'] = text_encodings
loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
average_val_loss_tensor[0, unet-1] += loss
@@ -498,7 +511,7 @@ def train(
if next_task == 'eval':
if exists(evaluate_config):
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
- evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+ evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
if is_master:
tracker.log(evaluation, step=step())
next_task = 'sample'
@@ -509,8 +522,8 @@ def train(
# Generate examples and save the model if we are the master
# Generate sample images
print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
- test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
- train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+ test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
"NumProcesses": accelerator.num_processes, "NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision "MixedPrecision": accelerator.mixed_precision
} }
accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy) tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json') tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict()) tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
@@ -556,10 +570,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance: if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance") raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data # Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1)) all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
world_size = accelerator.num_processes world_size = accelerator.num_processes
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
seed = config.seed, seed = config.seed,
) )
# If clip is in the model, we need to remove it for compatibility with deepspeed
clip = None
if config.decoder.clip is not None:
clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
config.decoder.clip = None
# Create the decoder model and print basic info # Create the decoder model and print basic info
decoder = config.decoder.create() decoder = config.decoder.create()
get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training)) get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
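The hunk above pulls CLIP out of the decoder config before the decoder is built, so the frozen CLIP weights never pass through DeepSpeed, and the live adapter is handed to train() separately. A condensed sketch of the same idea; the config attribute names follow the diff, while the function name and usage comment are illustrative:

def split_clip_from_decoder_config(decoder_config):
    # build CLIP once, outside the decoder, so DeepSpeed never has to wrap its frozen weights
    clip = None
    if decoder_config.clip is not None:
        clip = decoder_config.clip.create()   # keep the adapter for embedding images/text during training
        decoder_config.clip = None            # the decoder itself is constructed without CLIP attached
    decoder = decoder_config.create()
    return decoder, clip

# usage (hypothetical): decoder, clip = split_clip_from_decoder_config(config.decoder)
#                       train(dataloaders, decoder, accelerator, clip = clip, ...)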
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
has_text_embeddings = config.data.text_embeddings_url is not None
conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
- has_clip_model = config.decoder.clip is not None
+ has_clip_model = clip is not None
data_source_string = ""
if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
train(dataloaders, decoder, accelerator,
+ clip=clip,
tracker=tracker,
inference_device=accelerator.device,
evaluate_config=config.evaluate,