Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-15 04:24:53 +01:00)

Compare commits (23 commits)
| SHA1 |
|---|
| 27f19ba7fa |
| 8f38339c2b |
| 6b9b4b9e5e |
| 44e09d5a4d |
| 34806663e3 |
| dc816b1b6e |
| 05192ffac4 |
| 9440411954 |
| 981d407792 |
| 7c5477b26d |
| be3bb868bf |
| 451de34871 |
| f22e8c8741 |
| 87432e93ad |
| d167378401 |
| 2d67d5821e |
| 748c7fe7af |
| 80046334ad |
| 36fb46a95e |
| 07abfcf45b |
| 2e35a9967d |
| 406e75043f |
| 9646dfc0e6 |
README.md (51 changed lines)
@@ -371,6 +371,7 @@ loss.backward()
 unet1 = Unet(
     dim = 128,
     image_embed_dim = 512,
+    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8),
@@ -395,7 +396,7 @@ decoder = Decoder(
 ).cuda()
 
 for unet_number in (1, 2):
-    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
     loss.backward()
 
 # do above for many steps
@@ -626,6 +627,18 @@ images = dalle2(
 # save your image (in this example, of size 256x256)
 ```
 
+Alternatively, you can also use <a href="https://github.com/mlfoundations/open_clip">Open Clip</a>
+
+```bash
+$ pip install open-clip-torch
+```
+
+```python
+from dalle2_pytorch import OpenClipAdapter
+
+clip = OpenClipAdapter()
+```
+
 Now you'll just have to worry about training the Prior and the Decoder!
 
 ## Inpainting
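Aside (not part of the diff): the adapter added in this README hunk is meant to be passed anywhere a `clip` instance is expected. A minimal sketch, assuming the `DiffusionPriorNetwork` / `DiffusionPrior` constructors shown elsewhere in this README:

```python
from dalle2_pytorch import OpenClipAdapter, DiffusionPriorNetwork, DiffusionPrior

clip = OpenClipAdapter()    # open_clip ViT-B/32, laion400m weights per the defaults above

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,            # drop-in wherever OpenAIClipAdapter was used before
    timesteps = 100,
    cond_drop_prob = 0.2
)
```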
@@ -860,25 +873,23 @@ unet1 = Unet(
     text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True,
 ).cuda()
 
 unet2 = Unet(
     dim = 16,
     image_embed_dim = 512,
-    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8, 16),
-    cond_on_text_encodings = True
 ).cuda()
 
 decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 1000,
-    condition_on_text_encodings = True
+    timesteps = 1000
 ).cuda()
 
 decoder_trainer = DecoderTrainer(
@@ -903,8 +914,8 @@ for unet_number in (1, 2):
 # after much training
 # you can sample from the exponentially moving averaged unets as so
 
-mock_image_embed = torch.randn(4, 512).cuda()
-images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
 ### Diffusion Prior Training
@@ -1112,7 +1123,8 @@ For detailed information on training the diffusion prior, please refer to the [d
 - [x] allow for unet to be able to condition non-cross attention style as well
 - [x] speed up inference, read up on papers (ddim)
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
-- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow
+- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
+- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 
 ## Citations
@@ -1241,4 +1253,25 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@misc{chen2022analog,
+    title         = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
+    author        = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
+    year          = {2022},
+    eprint        = {2208.04202},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.CV}
+}
+```
+
+```bibtex
+@article{Qiao2019WeightS,
+    title   = {Weight Standardization},
+    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+    journal = {ArXiv},
+    year    = {2019},
+    volume  = {abs/1903.10520}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
dalle2_pytorch/dalle2_pytorch.py
@@ -8,6 +8,7 @@ from pathlib import Path
 
 import torch
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from torch import nn, einsum
 import torchvision.transforms as T
 
@@ -37,6 +38,8 @@ from coca_pytorch import CoCa
 
 NAT = 1. / math.log(2.)
 
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions
 
 def exists(val):
@@ -108,6 +111,28 @@ def pad_tuple_to_length(t, length, fillvalue = None):
         return t
     return (*t, *((fillvalue,) * remain_length))
 
+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+    if isinstance(fn, nn.ModuleList):
+        return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+    condition = kwargs.pop('condition', None)
+
+    if exists(condition) and not condition(fn):
+        return fn
+
+    @wraps(fn)
+    def inner(*args):
+        input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+        if not input_needs_grad:
+            return fn(*args)
+
+        return checkpoint(fn, *args)
+
+    return inner
+
 # for controlling freezing of CLIP
 
 def set_module_requires_grad_(module, requires_grad):
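Aside (not part of the diff): the helper added above wraps a callable so `torch.utils.checkpoint` is only engaged when some tensor argument actually requires grad. A self-contained sketch of the same pattern, independent of this repo's `maybe`/`exists` helpers:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import wraps

def make_checkpointable(fn):
    @wraps(fn)
    def inner(*args):
        # only checkpoint when a tensor input requires grad; otherwise the
        # recomputation in backward would be wasted work
        needs_grad = any(isinstance(a, torch.Tensor) and a.requires_grad for a in args)
        return checkpoint(fn, *args) if needs_grad else fn(*args)
    return inner

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
block_ckpt = make_checkpointable(block)

x = torch.randn(8, 64, requires_grad = True)
y = block_ckpt(x)      # activations inside `block` are recomputed during backward
y.sum().backward()
```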
@@ -339,6 +364,75 @@ class OpenAIClipAdapter(BaseClipAdapter):
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
 
+class OpenClipAdapter(BaseClipAdapter):
+    def __init__(
+        self,
+        name = 'ViT-B/32',
+        pretrained = 'laion400m_e32'
+    ):
+        import open_clip
+        clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
+
+        super().__init__(clip)
+        self.eos_id = 49407
+
+        text_attention_final = self.find_layer('ln_final')
+        self.handle = text_attention_final.register_forward_hook(self._hook)
+        self.clip_normalize = preprocess.transforms[-1]
+        self.cleared = False
+
+    def find_layer(self, layer):
+        modules = dict([*self.clip.named_modules()])
+        return modules.get(layer, None)
+
+    def clear(self):
+        if self.cleared:
+            return
+
+        self.handle()
+
+    def _hook(self, _, inputs, outputs):
+        self.text_encodings = outputs
+
+    @property
+    def dim_latent(self):
+        return 512
+
+    @property
+    def image_size(self):
+        return self.clip.visual.image_size
+
+    @property
+    def image_channels(self):
+        return 3
+
+    @property
+    def max_text_len(self):
+        return self.clip.context_length
+
+    @torch.no_grad()
+    def embed_text(self, text):
+        text = text[..., :self.max_text_len]
+
+        is_eos_id = (text == self.eos_id)
+        text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+        text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+        assert not self.cleared
+
+        text_embed = self.clip.encode_text(text)
+        text_encodings = self.text_encodings
+        text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+        del self.text_encodings
+        return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
+
+    @torch.no_grad()
+    def embed_image(self, image):
+        assert not self.cleared
+        image = self.validate_and_resize_image(image)
+        image = self.clip_normalize(image)
+        image_embed = self.clip.encode_image(image)
+        return EmbeddedImage(l2norm(image_embed.float()), None)
+
 # classifier free guidance functions
 
 def prob_mask_like(shape, prob, device):
@@ -547,34 +641,40 @@ class NoiseScheduler(nn.Module):
 # diffusion prior
 
 class LayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = -1, keepdim = True).detach()
 
         var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = -1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class ChanLayerNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5, stable = False):
+    def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
         super().__init__()
         self.eps = eps
+        self.fp16_eps = fp16_eps
         self.stable = stable
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
 
     def forward(self, x):
+        eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
         if self.stable:
             x = x / x.amax(dim = 1, keepdim = True).detach()
 
         var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
         mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * (var + self.eps).rsqrt() * self.g
+        return (x - mean) * (var + eps).rsqrt() * self.g
 
 class Residual(nn.Module):
     def __init__(self, fn):
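Aside (not part of the diff): the change above picks a larger epsilon when activations are half precision, since adding 1e-5 to a tiny fp16 variance is easily swamped by rounding and the subsequent `rsqrt` can explode. A minimal illustration of the same dtype-dependent selection rule:

```python
import torch

def pick_eps(x, eps = 1e-5, fp16_eps = 1e-3):
    # same selection rule as the LayerNorm / ChanLayerNorm forward above
    return eps if x.dtype == torch.float32 else fp16_eps

print(pick_eps(torch.randn(2, 8)))                          # 1e-05 for float32
print(pick_eps(torch.randn(2, 8, dtype = torch.float16)))   # 0.001 for float16
```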
@@ -695,11 +795,12 @@ class Attention(nn.Module):
         dropout = 0.,
         causal = False,
         rotary_emb = None,
-        pb_relax_alpha = 128
+        cosine_sim = True,
+        cosine_sim_scale = 16
     ):
         super().__init__()
-        self.pb_relax_alpha = pb_relax_alpha
-        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
+        self.cosine_sim = cosine_sim
 
         self.heads = heads
         inner_dim = dim_head * heads
@@ -739,6 +840,13 @@ class Attention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
+        # whether to use cosine sim
+
+        if self.cosine_sim:
+            q, k = map(l2norm, (q, k))
+
+        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
+
         # calculate query / key similarities
 
         sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -764,10 +872,7 @@ class Attention(nn.Module):
 
         # attention
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        sim = sim * self.pb_relax_alpha
-
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
         attn = self.dropout(attn)
 
         # aggregate values
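Aside (not part of the diff): the two Attention hunks above swap the PB-relax trick for cosine-similarity attention: queries and keys are l2-normalized, a fixed scale is split across q and k, and the softmax runs in fp32. An illustrative standalone sketch, with per-head keys and helper names that are assumptions rather than this file's API:

```python
import math
import torch
import torch.nn.functional as F
from torch import einsum

def cosine_sim_attention(q, k, v, scale = 16):
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))    # l2-normalize queries and keys
    q, k = map(lambda t: t * math.sqrt(scale), (q, k))        # split the fixed scale across q and k
    sim = einsum('b h i d, b h j d -> b h i j', q, k)         # logits bounded by the scale
    attn = sim.softmax(dim = -1, dtype = torch.float32)       # softmax in fp32 for stability
    return einsum('b h i j, b h j d -> b h i d', attn.type_as(v), v)

q = k = v = torch.randn(1, 8, 128, 64)
out = cosine_sim_attention(q, k, v)   # (1, 8, 128, 64)
```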
@@ -834,9 +939,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -864,6 +972,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -885,12 +997,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
@@ -940,13 +1059,16 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
@@ -1048,45 +1170,50 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
-            x_recon = pred
+            x_start = pred
         else:
-            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+            x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised and not self.predict_x_start:
-            x_recon.clamp_(-1., 1.)
+            x_start.clamp_(-1., 1.)
 
         if self.predict_x_start and self.sampling_clamp_l2norm:
-            x_recon = l2norm(x_recon) * self.image_embed_scale
+            x_start = l2norm(x_start) * self.image_embed_scale
 
-        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
 
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1104,6 +1231,8 @@ class DiffusionPrior(nn.Module):
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
@@ -1113,7 +1242,9 @@ class DiffusionPrior(nn.Module):
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
@@ -1148,18 +1279,27 @@ class DiffusionPrior(nn.Module):
         is_ddim = timesteps < self.noise_scheduler.num_timesteps
 
         if not is_ddim:
-            return self.p_sample_loop_ddpm(*args, **kwargs)
+            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+        else:
+            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
 
-        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+        image_embed = normalized_image_embed / self.image_embed_scale
+        return image_embed
 
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
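Aside (not part of the diff): the self-conditioning hunks follow the Analog Bits recipe (https://arxiv.org/abs/2208.04202). During training, half the time the network first predicts x_start without self-conditioning and then conditions on that detached prediction; at sampling time, each step feeds back the previous step's x_start estimate. A toy, self-contained sketch of the idea; the denoiser below is a placeholder, not this repo's API:

```python
import random
import torch
from torch import nn

class ToyDenoiser(nn.Module):
    # predicts x_start from (noisy input, optional previous x_start estimate)
    def __init__(self, dim = 16):
        super().__init__()
        self.net = nn.Linear(dim * 2, dim)

    def forward(self, x_noisy, self_cond = None):
        self_cond = torch.zeros_like(x_noisy) if self_cond is None else self_cond
        return self.net(torch.cat((x_noisy, self_cond), dim = -1))

model = ToyDenoiser()
x_noisy = torch.randn(4, 16)

# training: condition on the model's own (detached) prediction 50% of the time
self_cond = None
if random.random() < 0.5:
    with torch.no_grad():
        self_cond = model(x_noisy).detach()
pred = model(x_noisy, self_cond = self_cond)

# sampling: reuse the previous step's x_start prediction as conditioning
x_start = None
for _ in range(3):   # stand-in for the reverse diffusion loop
    x_start = model(x_noisy, self_cond = x_start)
```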
@@ -1213,8 +1353,6 @@ class DiffusionPrior(nn.Module):
 
         # retrieve original unscaled image embed
 
-        image_embeds /= self.image_embed_scale
-
         text_embeds = text_cond['text_embed']
 
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -1313,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
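Aside (not part of the diff): per the referenced paper, the convolution above standardizes each output-channel slice of the kernel before applying it; in the notation of the code, with `eps` chosen by dtype as shown:

```latex
\hat{W}_{o,c,h,w} = \frac{W_{o,c,h,w} - \mu_o}{\sqrt{\sigma_o^2 + \epsilon}},\qquad
\mu_o = \frac{1}{C_{in} K K}\sum_{c,h,w} W_{o,c,h,w},\qquad
\sigma_o^2 = \frac{1}{C_{in} K K}\sum_{c,h,w}\bigl(W_{o,c,h,w} - \mu_o\bigr)^2
```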
@@ -1331,10 +1489,13 @@ class Block(nn.Module):
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
 
@@ -1357,7 +1518,9 @@ class ResnetBlock(nn.Module):
         *,
         cond_dim = None,
         time_cond_dim = None,
-        groups = 8
+        groups = 8,
+        weight_standardization = False,
+        cosine_sim_cross_attn = False
     ):
         super().__init__()
 
@@ -1377,12 +1540,13 @@ class ResnetBlock(nn.Module):
                 'b (h w) c',
                 CrossAttention(
                     dim = dim_out,
-                    context_dim = cond_dim
+                    context_dim = cond_dim,
+                    cosine_sim = cosine_sim_cross_attn
                 )
             )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None, cond = None):
@@ -1412,11 +1576,12 @@ class CrossAttention(nn.Module):
         heads = 8,
         dropout = 0.,
         norm_context = False,
-        pb_relax_alpha = 32 ** 2
+        cosine_sim = False,
+        cosine_sim_scale = 16
     ):
         super().__init__()
-        self.pb_relax_alpha = pb_relax_alpha
-        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
+        self.cosine_sim = cosine_sim
+        self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -1452,7 +1617,10 @@ class CrossAttention(nn.Module):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
+        if self.cosine_sim:
+            q, k = map(l2norm, (q, k))
+
+        q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
 
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
         max_neg_value = -torch.finfo(sim.dtype).max
@@ -1462,10 +1630,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        sim = sim * self.pb_relax_alpha
-
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1476,7 +1641,8 @@ class LinearAttention(nn.Module):
         self,
         dim,
         dim_head = 32,
-        heads = 8
+        heads = 8,
+        **kwargs
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -1494,6 +1660,7 @@ class LinearAttention(nn.Module):
 
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
+        seq_len = x * y
 
         fmap = self.norm(fmap)
         q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
@@ -1503,6 +1670,9 @@ class LinearAttention(nn.Module):
         k = k.softmax(dim = -2)
 
         q = q * self.scale
+        v = l2norm(v)
+
+        k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
 
         context = einsum('b n d, b n e -> b d e', k, v)
         out = einsum('b n d, b d e -> b n e', q, context)
@@ -1538,6 +1708,38 @@ class CrossEmbedLayer(nn.Module):
         fmaps = tuple(map(lambda conv: conv(x), self.convs))
         return torch.cat(fmaps, dim = 1)
 
+class UpsampleCombiner(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        enabled = False,
+        dim_ins = tuple(),
+        dim_outs = tuple()
+    ):
+        super().__init__()
+        assert len(dim_ins) == len(dim_outs)
+        self.enabled = enabled
+
+        if not self.enabled:
+            self.dim_out = dim
+            return
+
+        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
+        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
+
+    def forward(self, x, fmaps = None):
+        target_size = x.shape[-1]
+
+        fmaps = default(fmaps, tuple())
+
+        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
+            return x
+
+        fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
+        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
+        return torch.cat((x, *outs), dim = 1)
+
 class Unet(nn.Module):
     def __init__(
         self,
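Aside (not part of the diff): a rough usage sketch of the combiner defined above, assuming the class and this file's `Block` / `resize_image_to` helpers are in scope; the shapes are purely illustrative:

```python
import torch

combiner = UpsampleCombiner(dim = 64, enabled = True, dim_ins = (128, 256), dim_outs = (64, 64))

x = torch.randn(1, 64, 64, 64)                                        # final upsample-path feature map
fmaps = [torch.randn(1, 128, 16, 16), torch.randn(1, 256, 32, 32)]    # coarser upsample outputs

out = combiner(x, fmaps)              # fmaps resized to 64x64, each projected to 64 channels, then concatenated
print(out.shape, combiner.dim_out)    # torch.Size([1, 192, 64, 64]) 192
```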
@@ -1557,7 +1759,10 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
+        self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
         sparse_attn = False,
+        cosine_sim_cross_attn = False,
+        cosine_sim_self_attn = False,
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
         cond_on_text_encodings = False,
         max_text_len = 256,
@@ -1566,6 +1771,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1575,6 +1781,8 @@ class Unet(nn.Module):
         scale_skip_connection = False,
         pixel_shuffle_upsample = True,
         final_conv_kernel_size = 1,
+        combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
+        checkpoint_during_training = False,
         **kwargs
     ):
         super().__init__()
@@ -1588,12 +1796,21 @@ class Unet(nn.Module):
 
         self.lowres_cond = lowres_cond
 
+        # whether to do self conditioning
+
+        self.self_cond = self_cond
+
         # determine dimensions
 
         self.channels = channels
         self.channels_out = default(channels_out, channels)
 
-        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+        # initial number of channels depends on
+        # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
+        # (2) self conditioning (bit diffusion paper)
+
+        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
+
         init_dim = default(init_dim, dim)
 
         self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
@@ -1677,7 +1894,7 @@ class Unet(nn.Module):
 
         # attention related params
 
-        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
 
         self_attn = cast_tuple(self_attn, num_stages)
 
@@ -1700,9 +1917,13 @@ class Unet(nn.Module):
 
         upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
 
+        # prepare resnet klass
+
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
+
         # give memory efficient unet an initial resnet block
 
-        self.init_resnet_block = ResnetBlock(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
+        self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
 
         # layers
 
@@ -1710,7 +1931,8 @@ class Unet(nn.Module):
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
         skip_connect_dims = [] # keeping track of skip connection dimensions
+        upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
 
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
             is_first = ind == 0
@@ -1728,17 +1950,17 @@ class Unet(nn.Module):
 
             self.downs.append(nn.ModuleList([
                 downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
-                ResnetBlock(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
         self.mid_attn = create_self_attn(mid_dim)
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+        self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
 
         for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
             is_last = ind >= (len(in_out) - 1)
@@ -1752,14 +1974,27 @@ class Unet(nn.Module):
             elif sparse_attn:
                 attention = Residual(LinearAttention(dim_out, **attn_kwargs))
 
+            upsample_combiner_dims.append(dim_out)
+
             self.ups.append(nn.ModuleList([
-                ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
-                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+                resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+                nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
                 attention,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
             ]))
 
-        self.final_resnet_block = ResnetBlock(dim * 2, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+        # whether to combine outputs from all upsample blocks for final resnet block
+
+        self.upsample_combiner = UpsampleCombiner(
+            dim = dim,
+            enabled = combine_upsample_fmaps,
+            dim_ins = upsample_combiner_dims,
+            dim_outs = (dim,) * len(upsample_combiner_dims)
+        )
+
+        # a final resnet block
+
+        self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
 
         out_dim_in = dim + (channels if lowres_cond else 0)
 
@@ -1767,6 +2002,10 @@ class Unet(nn.Module):
 
         zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
 
+        # whether to checkpoint during training
+
+        self.checkpoint_during_training = checkpoint_during_training
+
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings
     def cast_model_parameters(
@@ -1783,7 +2022,7 @@ class Unet(nn.Module):
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
-            cond_on_lowres_noise == self.cond_on_lowres_noise and \
+            lowres_noise_cond == self.lowres_noise_cond and \
            channels_out == self.channels_out:
             return self
 
@@ -1824,7 +2063,9 @@ class Unet(nn.Module):
         image_cond_drop_prob = 0.,
         text_cond_drop_prob = 0.,
         blur_sigma = None,
-        blur_kernel_size = None
+        blur_kernel_size = None,
+        disable_checkpoint = False,
+        self_cond = None
     ):
         batch_size, device = x.shape[0], x.device
 
@@ -1832,6 +2073,14 @@ class Unet(nn.Module):
 
         assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
 
+        # concat self conditioning, if needed
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros_like(x))
+            x = torch.cat((x, self_cond), dim = 1)
+
+        # concat low resolution conditioning
+
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
@@ -1946,16 +2195,29 @@ class Unet(nn.Module):
         c = self.norm_cond(c)
         mid_c = self.norm_mid_cond(mid_c)
 
+        # gradient checkpointing
+
+        can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+        apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+        # make checkpointable modules
+
+        init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+        can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+        downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
         # initial resnet block
 
-        if exists(self.init_resnet_block):
-            x = self.init_resnet_block(x, t)
+        if exists(init_resnet_block):
+            x = init_resnet_block(x, t)
 
         # go through the layers of the unet, down and up
 
-        hiddens = []
+        down_hiddens = []
+        up_hiddens = []
 
-        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
             if exists(pre_downsample):
                 x = pre_downsample(x)
 
@@ -1963,24 +2225,24 @@ class Unet(nn.Module):
 
             for resnet_block in resnet_blocks:
                 x = resnet_block(x, t, c)
-                hiddens.append(x)
+                down_hiddens.append(x.contiguous())
 
             x = attn(x)
-            hiddens.append(x.contiguous())
+            down_hiddens.append(x.contiguous())
 
             if exists(post_downsample):
                 x = post_downsample(x)
 
-        x = self.mid_block1(x, t, mid_c)
+        x = mid_block1(x, t, mid_c)
 
-        if exists(self.mid_attn):
-            x = self.mid_attn(x)
+        if exists(mid_attn):
+            x = mid_attn(x)
 
-        x = self.mid_block2(x, t, mid_c)
+        x = mid_block2(x, t, mid_c)
 
-        connect_skip = lambda fmap: torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)
+        connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
 
-        for init_block, resnet_blocks, attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn, upsample in ups:
             x = connect_skip(x)
             x = init_block(x, t, c)
 
@@ -1989,11 +2251,15 @@ class Unet(nn.Module):
                 x = resnet_block(x, t, c)
 
             x = attn(x)
+
+            up_hiddens.append(x.contiguous())
             x = upsample(x)
 
+        x = self.upsample_combiner(x, up_hiddens)
+
         x = torch.cat((x, r), dim = 1)
 
-        x = self.final_resnet_block(x, t)
+        x = final_resnet_block(x, t)
 
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -2346,6 +2612,14 @@ class Decoder(nn.Module):
         index = unet_number - 1
         return self.unets[index]
 
+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
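Aside (not part of the diff): `parse_unet_output` just centralizes the learned-variance split that `p_mean_variance` used to do inline. When `learned_variance` is on, the unet emits twice the channels and the second half is the variance interpolation fraction; a tiny illustration with assumed shapes:

```python
import torch

output = torch.randn(2, 6, 64, 64)           # unet predicts 2 * channels when learned_variance = True
pred, var_frac = output.chunk(2, dim = 1)    # what parse_unet_output returns as UnetOutput fields
print(pred.shape, var_frac.shape)            # torch.Size([2, 3, 64, 64]) torch.Size([2, 3, 64, 64])
```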
@@ -2384,23 +2658,22 @@ class Decoder(nn.Module):
        x = x.clamp(-s, s) / s
        return x

-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

        if predict_x_start:
-            x_recon = pred
+            x_start = pred
        else:
-            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

        if clip_denoised:
-            x_recon = self.dynamic_threshold(x_recon)
+            x_start = self.dynamic_threshold(x_start)

-        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)

        if learned_variance:
            # if learned variance, posterio variance and posterior log variance are predicted by the network
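
The `x = x.clamp(-s, s) / s` context line at the top of this hunk is the tail of the decoder's dynamic thresholding (the trick from the Imagen paper): pick a per-sample quantile `s` of the absolute predicted `x_start`, clamp to `[-s, s]`, then rescale back into `[-1, 1]`. A rough self-contained sketch of the idea, not the repo's exact implementation (the 0.95 percentile here is an assumed value):

```python
import torch

def dynamic_threshold(x, percentile = 0.95):
    # s is the per-sample percentile of |x|, but never below 1, so already
    # well-behaved samples are simply clamped to [-1, 1]
    s = torch.quantile(x.abs().flatten(start_dim = 1), percentile, dim = 1)
    s = s.clamp(min = 1.)
    s = s.view(-1, *((1,) * (x.ndim - 1)))
    return x.clamp(-s, s) / s

x0 = torch.randn(4, 3, 32, 32) * 2
assert dynamic_threshold(x0).abs().max() <= 1
```
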
@@ -2416,16 +2689,17 @@ class Decoder(nn.Module):
            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
            posterior_variance = posterior_log_variance.exp()

-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
        b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start

    @torch.no_grad()
    def p_sample_loop_ddpm(
@@ -2451,6 +2725,8 @@ class Decoder(nn.Module):
        b = shape[0]
        img = torch.randn(shape, device = device)

+        x_start = None # for self-conditioning
+
        is_inpaint = exists(inpaint_image)
        resample_times = inpaint_resample_times if is_inpaint else 1
@@ -2478,13 +2754,16 @@ class Decoder(nn.Module):
                    noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
                    img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

-                img = self.p_sample(
+                self_cond = x_start if unet.self_cond else None
+
+                img, x_start = self.p_sample(
                    unet,
                    img,
                    times,
                    image_embed = image_embed,
                    text_encodings = text_encodings,
                    cond_scale = cond_scale,
+                    self_cond = self_cond,
                    lowres_cond_img = lowres_cond_img,
                    lowres_noise_level = lowres_noise_level,
                    predict_x_start = predict_x_start,
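
The thread running through these sampling hunks is self-conditioning (the idea from the Analog Bits paper): `p_sample` now also returns its `x_start` estimate, and the loop feeds that estimate back into the next denoising step whenever the unet was built with `self_cond` enabled. Stripped of the inpainting and conditioning details, the control flow is roughly the following (all names here are illustrative, not the repo's API):

```python
import torch

@torch.no_grad()
def sample_with_self_conditioning(denoise_step, shape, timesteps, use_self_cond = True):
    # denoise_step(img, t, self_cond) -> (next_img, x_start_estimate)
    img = torch.randn(shape)
    x_start = None  # nothing to condition on at the very first step

    for t in reversed(range(timesteps)):
        self_cond = x_start if use_self_cond else None
        img, x_start = denoise_step(img, t, self_cond)

    return img

# dummy denoiser so the sketch actually runs
def dummy_step(img, t, self_cond):
    x_start = img * 0.9 if self_cond is None else (img + self_cond) * 0.45
    return x_start + 0.01 * torch.randn_like(img), x_start

samples = sample_with_self_conditioning(dummy_step, (1, 3, 8, 8), timesteps = 10)
```
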
@@ -2543,6 +2822,8 @@ class Decoder(nn.Module):

        img = torch.randn(shape, device = device)

+        x_start = None # for self-conditioning
+
        if not is_latent_diffusion:
            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
@@ -2563,10 +2844,11 @@ class Decoder(nn.Module):
                noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
                img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)

-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            self_cond = x_start if unet.self_cond else None

-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)

            if predict_x_start:
                x_start = pred
@@ -2623,21 +2905,37 @@ class Decoder(nn.Module):

        x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

-        model_output = unet(
-            x_noisy,
-            times,
+        # unet kwargs
+
+        unet_kwargs = dict(
            image_embed = image_embed,
            text_encodings = text_encodings,
            lowres_cond_img = lowres_cond_img,
            lowres_noise_level = lowres_noise_level,
+        )
+
+        # self conditioning
+
+        self_cond = None
+
+        if unet.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
+                self_cond = self_cond.detach()
+
+        # forward to get model prediction
+
+        unet_output = unet(
+            x_noisy,
+            times,
+            **unet_kwargs,
+            self_cond = self_cond,
            image_cond_drop_prob = self.image_cond_drop_prob,
            text_cond_drop_prob = self.text_cond_drop_prob,
        )

-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)

        target = noise if not predict_x_start else x_start
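
On the training side, the new `p_losses` logic is the usual self-conditioning recipe: half of the time, run the unet once without gradients to get an `x_start` estimate, detach it, and feed it back as an extra input for the real forward pass. A condensed sketch of just that pattern (the tiny model and loss below are stand-ins, not the Decoder's actual unet):

```python
import random

import torch
import torch.nn.functional as F
from torch import nn

class TinyDenoiser(nn.Module):
    # consumes the noisy image concatenated with a self-conditioning block
    def __init__(self, channels = 3):
        super().__init__()
        self.net = nn.Conv2d(channels * 2, channels, 3, padding = 1)

    def forward(self, x_noisy, self_cond = None):
        self_cond = self_cond if self_cond is not None else torch.zeros_like(x_noisy)
        return self.net(torch.cat((x_noisy, self_cond), dim = 1))

model = TinyDenoiser()
x_noisy = torch.randn(2, 3, 16, 16)
noise = torch.randn_like(x_noisy)

self_cond = None
if random.random() < 0.5:
    with torch.no_grad():
        self_cond = model(x_noisy).detach()  # free estimate, no gradient flows through it

pred = model(x_noisy, self_cond = self_cond)
loss = F.mse_loss(pred, noise)
loss.backward()
```
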
@@ -2660,7 +2958,7 @@ class Decoder(nn.Module):
            # if learning the variance, also include the extra weight kl loss

            true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-            model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)

            # kl loss with detached model predicted mean, for stability reasons as in paper
@@ -2885,7 +3183,7 @@ class DALLE2(nn.Module):
        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)

        text_cond = text if self.decoder_need_text_cond else None
-        images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
+        images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

        if return_pil_images:
            images = list(map(self.to_pil, images.unbind(dim = 0)))
dalle2_pytorch/trainer.py
@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler

from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -174,20 +174,25 @@ class DiffusionPriorTrainer(nn.Module):
    def __init__(
        self,
        diffusion_prior,
-        accelerator,
+        accelerator = None,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        eps = 1e-6,
        max_grad_norm = None,
        group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
        **kwargs
    ):
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
-        assert isinstance(accelerator, Accelerator)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+        accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
+
+        if not exists(accelerator):
+            accelerator = Accelerator(**accelerator_kwargs)

        # assign some helpful member vars
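
With this change the trainer builds its own `Accelerator` when none is passed in, and `warmup_steps` / `cosine_decay_max_steps` become optional knobs rather than requirements. A hedged usage sketch of the new constructor arguments, with a minimal prior; the step counts are placeholders, not recommended settings:

```python
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
from dalle2_pytorch.trainer import DiffusionPriorTrainer

prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 100,
    condition_on_text_encodings = False  # assuming training on precomputed embeddings here
)

# no accelerator passed in: per the hunk above, the trainer constructs one itself
trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    warmup_steps = 1000,              # placeholder
    cosine_decay_max_steps = 100_000  # placeholder
)
```
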
@@ -229,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
            **self.optim_kwargs,
            **kwargs
        )

-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

        self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
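
The scheduler selection above follows a common PyTorch pattern: a `CosineAnnealingLR` decay wrapped by pytorch-warmup's `LinearWarmup`, whose `dampening()` context scales the learning rate during the warmup window and then lets the cosine schedule take over. A self-contained sketch of how the two pieces interact (toy model and step counts, not the trainer's actual loop):

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_warmup as warmup

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)

scheduler = CosineAnnealingLR(optimizer, T_max = 1000)                   # decay over 1000 steps
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 100)   # ramp up over the first 100

for step in range(200):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # dampening() rescales the lr while warming up, then the cosine schedule proceeds untouched
    with warmup_scheduler.dampening():
        scheduler.step()
```
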
@@ -267,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
        # FIXME: LambdaLR can't be saved due to pickling issues
        save_obj = dict(
            optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
            warmup_scheduler = self.warmup_scheduler,
            model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
            version = version.parse(__version__),
@@ -300,7 +309,7 @@ class DiffusionPriorTrainer(nn.Module):

        # all processes need to load checkpoint. no restriction here
        if isinstance(path_or_state, str):
-            path = Path(path)
+            path = Path(path_or_state)
            assert path.exists()
            loaded_obj = torch.load(str(path), map_location=self.device)
@@ -313,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
        # unwrap the model when loading from checkpoint
        self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))

        self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

        # set warmupstep
        if exists(self.warmup_scheduler):
@@ -346,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):

        # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
        if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                self.scheduler.step()

        if self.use_ema:
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
        wd = 1e-2,
        eps = 1e-8,
        warmup_steps = None,
+        cosine_decay_max_steps = None,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
@@ -450,7 +463,7 @@ class DecoderTrainer(nn.Module):
        # be able to finely customize learning rate, weight decay
        # per unet

-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))

        assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -458,7 +471,7 @@ class DecoderTrainer(nn.Module):
        schedulers = []
        warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
            if isinstance(unet, nn.Identity):
                optimizers.append(None)
                schedulers.append(None)
@@ -474,7 +487,11 @@ class DecoderTrainer(nn.Module):
                )

                optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

                warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                warmup_schedulers.append(warmup_scheduler)
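
Because `cosine_decay_max_steps` goes through the same `cast_tuple` path as the other per-unet hyperparameters, it can be supplied either as a single value or as one value per unet. A hedged sketch of what that configuration could look like, assuming `decoder` is a two-unet `Decoder` built as elsewhere in this README (the numbers are placeholders):

```python
from dalle2_pytorch.trainer import DecoderTrainer

# `decoder` is assumed to be an already constructed two-unet Decoder
decoder_trainer = DecoderTrainer(
    decoder,
    lr = (3e-4, 1e-4),                            # one learning rate per unet
    warmup_steps = (500, 500),                    # linear warmup per unet
    cosine_decay_max_steps = (200_000, 200_000),  # cosine decay horizon per unet
)
```
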
@@ -554,9 +571,15 @@ class DecoderTrainer(nn.Module):

        for ind in range(0, self.num_unets):
            optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
            optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -577,10 +600,18 @@ class DecoderTrainer(nn.Module):

            optimizer_key = f'optim{ind}'
            optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
            warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+
+            if exists(optimizer):
                optimizer.load_state_dict(loaded_obj[optimizer_key])

+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
            if exists(warmup_scheduler):
                warmup_scheduler.last_step = last_step
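
Taken together, the save and load hunks mean a checkpoint now carries one `sched{i}` entry next to each `optim{i}` entry, so cosine-annealing progress survives a restart. A self-contained sketch of that per-unet pattern with toy modules standing in for the trainer's internals (path and sizes are placeholders):

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR

# toy stand-ins for the per-unet optimizers and schedulers the trainer keeps
unets = [nn.Linear(4, 4), nn.Linear(4, 4)]
optimizers = [torch.optim.AdamW(u.parameters(), lr = 1e-4) for u in unets]
schedulers = [CosineAnnealingLR(o, T_max = 1000) for o in optimizers]

# save: one optim{i} / sched{i} pair per unet, mirroring the keys in the hunks above
save_obj = {}
for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
    save_obj[f'optim{ind}'] = optimizer.state_dict()
    save_obj[f'sched{ind}'] = scheduler.state_dict()

torch.save(save_obj, './trainer-state.pth')

# load: restore both, so the cosine schedule resumes where it left off
loaded_obj = torch.load('./trainer-state.pth')
for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
    optimizer.load_state_dict(loaded_obj[f'optim{ind}'])
    scheduler.load_state_dict(loaded_obj[f'sched{ind}'])
```
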
dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.0.5'
+__version__ = '1.8.2'