mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 00:34:19 +01:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c6c91829d | ||
|
|
ad17c69ab6 | ||
|
|
0b4ec34efb | ||
|
|
f027b82e38 | ||
|
|
8cc9016cb0 | ||
|
|
1d8f37befe | ||
|
|
faebf4c8b8 | ||
|
|
b8e8d3c164 | ||
|
|
8e2416b49b | ||
|
|
f37c26e856 | ||
|
|
27a33e1b20 | ||
|
|
6f941a219a |
10
README.md
10
README.md
@@ -1,6 +1,6 @@
|
|||||||
<img src="./dalle2.png" width="450px"></img>
|
<img src="./dalle2.png" width="450px"></img>
|
||||||
|
|
||||||
## DALL-E 2 - Pytorch (wip)
|
## DALL-E 2 - Pytorch
|
||||||
|
|
||||||
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
|
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
|
||||||
|
|
||||||
@@ -410,11 +410,13 @@ Offer training wrappers
|
|||||||
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||||
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||||
- [x] add efficient attention in unet
|
- [x] add efficient attention in unet
|
||||||
- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||||
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||||
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
|
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
- [ ] train on a toy task, offer in colab
|
- [ ] train on a toy task, offer in colab
|
||||||
|
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||||
|
|
||||||
|
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||||
from x_clip import CLIP
|
from x_clip import CLIP
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ def main():
|
|||||||
@click.command()
|
@click.command()
|
||||||
@click.argument('text')
|
@click.argument('text')
|
||||||
def dream(text):
|
def dream(text):
|
||||||
return image
|
return 'not ready yet'
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import math
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -463,11 +464,11 @@ class DiffusionPrior(nn.Module):
|
|||||||
net,
|
net,
|
||||||
*,
|
*,
|
||||||
clip,
|
clip,
|
||||||
timesteps=1000,
|
timesteps = 1000,
|
||||||
cond_drop_prob=0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type="l1",
|
loss_type = "l1",
|
||||||
predict_x0=True,
|
predict_x0 = True,
|
||||||
beta_schedule="cosine",
|
beta_schedule = "cosine",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
@@ -497,7 +498,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
alphas = 1. - betas
|
alphas = 1. - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||||
|
|
||||||
timesteps, = betas.shape
|
timesteps, = betas.shape
|
||||||
@@ -820,16 +821,18 @@ class Unet(nn.Module):
|
|||||||
image_embed_dim,
|
image_embed_dim,
|
||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
num_image_tokens = 4,
|
num_image_tokens = 4,
|
||||||
|
num_time_tokens = 2,
|
||||||
out_dim = None,
|
out_dim = None,
|
||||||
dim_mults=(1, 2, 4, 8),
|
dim_mults=(1, 2, 4, 8),
|
||||||
channels = 3,
|
channels = 3,
|
||||||
|
attn_dim_head = 32,
|
||||||
|
attn_heads = 8,
|
||||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||||
lowres_cond_upsample_mode = 'bilinear',
|
|
||||||
blur_sigma = 0.1,
|
|
||||||
blur_kernel_size = 3,
|
|
||||||
sparse_attn = False,
|
sparse_attn = False,
|
||||||
sparse_attn_window = 8, # window size for sparse attention
|
sparse_attn_window = 8, # window size for sparse attention
|
||||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||||
|
cond_on_text_encodings = False,
|
||||||
|
cond_on_image_embeds = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# save locals to take care of some hyperparameters for cascading DDPM
|
# save locals to take care of some hyperparameters for cascading DDPM
|
||||||
@@ -841,9 +844,6 @@ class Unet(nn.Module):
|
|||||||
# for eventual cascading diffusion
|
# for eventual cascading diffusion
|
||||||
|
|
||||||
self.lowres_cond = lowres_cond
|
self.lowres_cond = lowres_cond
|
||||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
|
||||||
self.lowres_blur_kernel_size = blur_kernel_size
|
|
||||||
self.lowres_blur_sigma = blur_sigma
|
|
||||||
|
|
||||||
# determine dimensions
|
# determine dimensions
|
||||||
|
|
||||||
@@ -862,8 +862,8 @@ class Unet(nn.Module):
|
|||||||
SinusoidalPosEmb(dim),
|
SinusoidalPosEmb(dim),
|
||||||
nn.Linear(dim, dim * 4),
|
nn.Linear(dim, dim * 4),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(dim * 4, cond_dim),
|
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
||||||
Rearrange('b d -> b 1 d')
|
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.image_to_cond = nn.Sequential(
|
self.image_to_cond = nn.Sequential(
|
||||||
@@ -873,11 +873,21 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||||
|
|
||||||
|
# finer control over whether to condition on image embeddings and text encodings
|
||||||
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||||
|
|
||||||
|
self.cond_on_text_encodings = cond_on_text_encodings
|
||||||
|
self.cond_on_image_embeds = cond_on_image_embeds
|
||||||
|
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||||
|
|
||||||
|
# attention related params
|
||||||
|
|
||||||
|
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||||
|
|
||||||
# layers
|
# layers
|
||||||
|
|
||||||
self.downs = nn.ModuleList([])
|
self.downs = nn.ModuleList([])
|
||||||
@@ -891,7 +901,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||||
Downsample(dim_out) if not is_last else nn.Identity()
|
Downsample(dim_out) if not is_last else nn.Identity()
|
||||||
]))
|
]))
|
||||||
@@ -899,7 +909,7 @@ class Unet(nn.Module):
|
|||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
|
|
||||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||||
|
|
||||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||||
@@ -908,7 +918,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
self.ups.append(nn.ModuleList([
|
self.ups.append(nn.ModuleList([
|
||||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||||
Upsample(dim_in)
|
Upsample(dim_in)
|
||||||
]))
|
]))
|
||||||
@@ -961,13 +971,6 @@ class Unet(nn.Module):
|
|||||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||||
|
|
||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
if self.training:
|
|
||||||
# when training, blur the low resolution conditional image
|
|
||||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
|
||||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
|
||||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
|
||||||
|
|
||||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
|
||||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||||
|
|
||||||
# time conditioning
|
# time conditioning
|
||||||
@@ -982,17 +985,22 @@ class Unet(nn.Module):
|
|||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
image_tokens = self.image_to_cond(image_embed)
|
image_tokens = None
|
||||||
|
|
||||||
image_tokens = torch.where(
|
if self.cond_on_image_embeds:
|
||||||
cond_prob_mask,
|
image_tokens = self.image_to_cond(image_embed)
|
||||||
image_tokens,
|
|
||||||
self.null_image_embed
|
image_tokens = torch.where(
|
||||||
)
|
cond_prob_mask,
|
||||||
|
image_tokens,
|
||||||
|
self.null_image_embed
|
||||||
|
)
|
||||||
|
|
||||||
# take care of text encodings (optional)
|
# take care of text encodings (optional)
|
||||||
|
|
||||||
if exists(text_encodings):
|
text_tokens = None
|
||||||
|
|
||||||
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
text_tokens = self.text_to_cond(text_encodings)
|
text_tokens = self.text_to_cond(text_encodings)
|
||||||
text_tokens = torch.where(
|
text_tokens = torch.where(
|
||||||
cond_prob_mask,
|
cond_prob_mask,
|
||||||
@@ -1002,12 +1010,15 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# main conditioning tokens (c)
|
# main conditioning tokens (c)
|
||||||
|
|
||||||
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
c = time_tokens
|
||||||
|
|
||||||
|
if exists(image_tokens):
|
||||||
|
c = torch.cat((c, image_tokens), dim = -2)
|
||||||
|
|
||||||
# text and image conditioning tokens (mid_c)
|
# text and image conditioning tokens (mid_c)
|
||||||
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
||||||
|
|
||||||
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||||
|
|
||||||
# go through the layers of the unet, down and up
|
# go through the layers of the unet, down and up
|
||||||
|
|
||||||
@@ -1036,6 +1047,44 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
|
class LowresConditioner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cond_upsample_mode = 'bilinear',
|
||||||
|
downsample_first = True,
|
||||||
|
blur_sigma = 0.1,
|
||||||
|
blur_kernel_size = 3,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.cond_upsample_mode = cond_upsample_mode
|
||||||
|
self.downsample_first = downsample_first
|
||||||
|
self.blur_sigma = blur_sigma
|
||||||
|
self.blur_kernel_size = blur_kernel_size
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
cond_fmap,
|
||||||
|
*,
|
||||||
|
target_image_size,
|
||||||
|
downsample_image_size = None,
|
||||||
|
blur_sigma = None,
|
||||||
|
blur_kernel_size = None
|
||||||
|
):
|
||||||
|
target_image_size = cast_tuple(target_image_size, 2)
|
||||||
|
|
||||||
|
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||||
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||||
|
|
||||||
|
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# when training, blur the low resolution conditional image
|
||||||
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||||
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||||
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
|
return cond_fmap
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1046,7 +1095,11 @@ class Decoder(nn.Module):
|
|||||||
cond_drop_prob = 0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type = 'l1',
|
loss_type = 'l1',
|
||||||
beta_schedule = 'cosine',
|
beta_schedule = 'cosine',
|
||||||
image_sizes = None # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
|
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||||
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
|
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||||
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip, CLIP)
|
assert isinstance(clip, CLIP)
|
||||||
@@ -1071,12 +1124,26 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||||
self.image_sizes = image_sizes
|
self.image_sizes = image_sizes
|
||||||
|
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||||
|
|
||||||
|
# cascading ddpm related stuff
|
||||||
|
|
||||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||||
|
|
||||||
|
self.to_lowres_cond = LowresConditioner(
|
||||||
|
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||||
|
downsample_first = lowres_downsample_first,
|
||||||
|
blur_sigma = blur_sigma,
|
||||||
|
blur_kernel_size = blur_kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# classifier free guidance
|
||||||
|
|
||||||
self.cond_drop_prob = cond_drop_prob
|
self.cond_drop_prob = cond_drop_prob
|
||||||
|
|
||||||
|
# noise schedule
|
||||||
|
|
||||||
if beta_schedule == "cosine":
|
if beta_schedule == "cosine":
|
||||||
betas = cosine_beta_schedule(timesteps)
|
betas = cosine_beta_schedule(timesteps)
|
||||||
elif beta_schedule == "linear":
|
elif beta_schedule == "linear":
|
||||||
@@ -1091,7 +1158,7 @@ class Decoder(nn.Module):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
alphas = 1. - betas
|
alphas = 1. - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||||
|
|
||||||
timesteps, = betas.shape
|
timesteps, = betas.shape
|
||||||
@@ -1124,6 +1191,25 @@ class Decoder(nn.Module):
|
|||||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||||
|
|
||||||
|
def get_unet(self, unet_number):
|
||||||
|
assert 0 < unet_number <= len(self.unets)
|
||||||
|
index = unet_number - 1
|
||||||
|
return self.unets[index]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||||
|
assert exists(unet_number) ^ exists(unet)
|
||||||
|
|
||||||
|
if exists(unet_number):
|
||||||
|
unet = self.get_unet(unet_number)
|
||||||
|
|
||||||
|
self.cuda()
|
||||||
|
self.unets.cpu()
|
||||||
|
|
||||||
|
unet.cuda()
|
||||||
|
yield
|
||||||
|
unet.cpu()
|
||||||
|
|
||||||
def get_text_encodings(self, text):
|
def get_text_encodings(self, text):
|
||||||
text_encodings = self.clip.text_transformer(text)
|
text_encodings = self.clip.text_transformer(text)
|
||||||
return text_encodings[:, 1:]
|
return text_encodings[:, 1:]
|
||||||
@@ -1184,6 +1270,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def q_sample(self, x_start, t, noise=None):
|
def q_sample(self, x_start, t, noise=None):
|
||||||
@@ -1223,25 +1310,37 @@ class Decoder(nn.Module):
|
|||||||
@eval_decorator
|
@eval_decorator
|
||||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
channels = self.channels
|
|
||||||
|
|
||||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||||
|
|
||||||
img = None
|
img = None
|
||||||
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
|
||||||
shape = (batch_size, channels, image_size, image_size)
|
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
|
||||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
with self.one_unet_in_gpu(unet = unet):
|
||||||
|
lowres_cond_img = self.to_lowres_cond(
|
||||||
|
img,
|
||||||
|
target_image_size = image_size
|
||||||
|
) if unet.lowres_cond else None
|
||||||
|
|
||||||
|
img = self.p_sample_loop(
|
||||||
|
unet,
|
||||||
|
(batch_size, channel, image_size, image_size),
|
||||||
|
image_embed = image_embed,
|
||||||
|
text_encodings = text_encodings,
|
||||||
|
cond_scale = cond_scale,
|
||||||
|
lowres_cond_img = lowres_cond_img
|
||||||
|
)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||||
unet_number = default(unet_number, 1)
|
unet_number = default(unet_number, 1)
|
||||||
assert 1 <= unet_number <= len(self.unets)
|
unet_index = unet_number - 1
|
||||||
|
|
||||||
index = unet_number - 1
|
unet = self.get_unet(unet_number)
|
||||||
unet = self.unets[index]
|
|
||||||
target_image_size = self.image_sizes[index]
|
target_image_size = self.image_sizes[unet_index]
|
||||||
|
|
||||||
b, c, h, w, device, = *image.shape, image.device
|
b, c, h, w, device, = *image.shape, image.device
|
||||||
|
|
||||||
@@ -1255,7 +1354,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||||
|
|
||||||
lowres_cond_img = image if index > 0 else None
|
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||||
ddpm_image = resize_image_to(image, target_image_size)
|
ddpm_image = resize_image_to(image, target_image_size)
|
||||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn, einsum
|
|
||||||
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
class LatentDiffusion(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x
|
|
||||||
532
dalle2_pytorch/vqgan_vae.py
Normal file
532
dalle2_pytorch/vqgan_vae.py
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
import copy
|
||||||
|
import math
|
||||||
|
from math import sqrt
|
||||||
|
from functools import partial, wraps
|
||||||
|
|
||||||
|
from vector_quantize_pytorch import VectorQuantize as VQ
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import grad as torch_grad
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
from einops import rearrange, reduce, repeat
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
MList = nn.ModuleList
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
# decorators
|
||||||
|
|
||||||
|
def eval_decorator(fn):
|
||||||
|
def inner(model, *args, **kwargs):
|
||||||
|
was_training = model.training
|
||||||
|
model.eval()
|
||||||
|
out = fn(model, *args, **kwargs)
|
||||||
|
model.train(was_training)
|
||||||
|
return out
|
||||||
|
return inner
|
||||||
|
|
||||||
|
def remove_vgg(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(self, *args, **kwargs):
|
||||||
|
has_vgg = hasattr(self, 'vgg')
|
||||||
|
if has_vgg:
|
||||||
|
vgg = self.vgg
|
||||||
|
delattr(self, 'vgg')
|
||||||
|
|
||||||
|
out = fn(self, *args, **kwargs)
|
||||||
|
|
||||||
|
if has_vgg:
|
||||||
|
self.vgg = vgg
|
||||||
|
|
||||||
|
return out
|
||||||
|
return inner
|
||||||
|
|
||||||
|
# keyword argument helpers
|
||||||
|
|
||||||
|
def pick_and_pop(keys, d):
|
||||||
|
values = list(map(lambda key: d.pop(key), keys))
|
||||||
|
return dict(zip(keys, values))
|
||||||
|
|
||||||
|
def group_dict_by_key(cond, d):
|
||||||
|
return_val = [dict(),dict()]
|
||||||
|
for key in d.keys():
|
||||||
|
match = bool(cond(key))
|
||||||
|
ind = int(not match)
|
||||||
|
return_val[ind][key] = d[key]
|
||||||
|
return (*return_val,)
|
||||||
|
|
||||||
|
def string_begins_with(prefix, str):
|
||||||
|
return str.startswith(prefix)
|
||||||
|
|
||||||
|
def group_by_key_prefix(prefix, d):
|
||||||
|
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
|
||||||
|
def groupby_prefix_and_trim(prefix, d):
|
||||||
|
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
# tensor helper functions
|
||||||
|
|
||||||
|
def log(t, eps = 1e-10):
|
||||||
|
return torch.log(t + eps)
|
||||||
|
|
||||||
|
def gradient_penalty(images, output, weight = 10):
|
||||||
|
batch_size = images.shape[0]
|
||||||
|
gradients = torch_grad(outputs = output, inputs = images,
|
||||||
|
grad_outputs = torch.ones(output.size(), device = images.device),
|
||||||
|
create_graph = True, retain_graph = True, only_inputs = True)[0]
|
||||||
|
|
||||||
|
gradients = rearrange(gradients, 'b ... -> b (...)')
|
||||||
|
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
|
||||||
|
|
||||||
|
def l2norm(t):
|
||||||
|
return F.normalize(t, dim = -1)
|
||||||
|
|
||||||
|
def leaky_relu(p = 0.1):
|
||||||
|
return nn.LeakyReLU(0.1)
|
||||||
|
|
||||||
|
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||||
|
t = t / alpha
|
||||||
|
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
||||||
|
return (t * alpha).softmax(dim = dim)
|
||||||
|
|
||||||
|
def safe_div(numer, denom, eps = 1e-8):
|
||||||
|
return numer / (denom + eps)
|
||||||
|
|
||||||
|
# gan losses
|
||||||
|
|
||||||
|
def hinge_discr_loss(fake, real):
|
||||||
|
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
|
||||||
|
|
||||||
|
def hinge_gen_loss(fake):
|
||||||
|
return -fake.mean()
|
||||||
|
|
||||||
|
def bce_discr_loss(fake, real):
|
||||||
|
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
|
||||||
|
|
||||||
|
def bce_gen_loss(fake):
|
||||||
|
return -log(torch.sigmoid(fake)).mean()
|
||||||
|
|
||||||
|
def grad_layer_wrt_loss(loss, layer):
|
||||||
|
return torch_grad(
|
||||||
|
outputs = loss,
|
||||||
|
inputs = layer,
|
||||||
|
grad_outputs = torch.ones_like(loss),
|
||||||
|
retain_graph = True
|
||||||
|
)[0].detach()
|
||||||
|
|
||||||
|
# vqgan vae
|
||||||
|
|
||||||
|
class LayerNormChan(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
eps = 1e-5
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||||
|
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||||
|
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||||
|
|
||||||
|
class Discriminator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims,
|
||||||
|
channels = 3,
|
||||||
|
groups = 16,
|
||||||
|
init_kernel_size = 5
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
dim_pairs = zip(dims[:-1], dims[1:])
|
||||||
|
|
||||||
|
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
|
||||||
|
|
||||||
|
for dim_in, dim_out in dim_pairs:
|
||||||
|
self.layers.append(nn.Sequential(
|
||||||
|
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
|
||||||
|
nn.GroupNorm(groups, dim_out),
|
||||||
|
leaky_relu()
|
||||||
|
))
|
||||||
|
|
||||||
|
dim = dims[-1]
|
||||||
|
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
|
||||||
|
nn.Conv2d(dim, dim, 1),
|
||||||
|
leaky_relu(),
|
||||||
|
nn.Conv2d(dim, 1, 4)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for net in self.layers:
|
||||||
|
x = net(x)
|
||||||
|
|
||||||
|
return self.to_logits(x)
|
||||||
|
|
||||||
|
class ContinuousPositionBias(nn.Module):
|
||||||
|
""" from https://arxiv.org/abs/2111.09883 """
|
||||||
|
|
||||||
|
def __init__(self, *, dim, heads, layers = 2):
|
||||||
|
super().__init__()
|
||||||
|
self.net = MList([])
|
||||||
|
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
|
||||||
|
|
||||||
|
for _ in range(layers - 1):
|
||||||
|
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
|
||||||
|
|
||||||
|
self.net.append(nn.Linear(dim, heads))
|
||||||
|
self.register_buffer('rel_pos', None, persistent = False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
n, device = x.shape[-1], x.device
|
||||||
|
fmap_size = int(sqrt(n))
|
||||||
|
|
||||||
|
if not exists(self.rel_pos):
|
||||||
|
pos = torch.arange(fmap_size, device = device)
|
||||||
|
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||||
|
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||||
|
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
|
||||||
|
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
|
||||||
|
self.register_buffer('rel_pos', rel_pos, persistent = False)
|
||||||
|
|
||||||
|
rel_pos = self.rel_pos.float()
|
||||||
|
|
||||||
|
for layer in self.net:
|
||||||
|
rel_pos = layer(rel_pos)
|
||||||
|
|
||||||
|
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||||
|
return x + bias
|
||||||
|
|
||||||
|
class GLUResBlock(nn.Module):
|
||||||
|
def __init__(self, chan, groups = 16):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||||
|
nn.GLU(dim = 1),
|
||||||
|
nn.GroupNorm(groups, chan),
|
||||||
|
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||||
|
nn.GLU(dim = 1),
|
||||||
|
nn.GroupNorm(groups, chan),
|
||||||
|
nn.Conv2d(chan, chan, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x) + x
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(self, chan, groups = 16):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||||
|
nn.GroupNorm(groups, chan),
|
||||||
|
leaky_relu(),
|
||||||
|
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||||
|
nn.GroupNorm(groups, chan),
|
||||||
|
leaky_relu(),
|
||||||
|
nn.Conv2d(chan, chan, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x) + x
|
||||||
|
|
||||||
|
class VQGanAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8,
|
||||||
|
dropout = 0.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
inner_dim = heads * dim_head
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.pre_norm = LayerNormChan(dim)
|
||||||
|
|
||||||
|
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
|
||||||
|
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||||
|
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.heads
|
||||||
|
height, width, residual = *x.shape[-2:], x.clone()
|
||||||
|
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
|
||||||
|
|
||||||
|
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
|
||||||
|
|
||||||
|
sim = self.cpb(sim)
|
||||||
|
|
||||||
|
attn = stable_softmax(sim, dim = -1)
|
||||||
|
attn = self.dropout(attn)
|
||||||
|
|
||||||
|
out = einsum('b h i j, b h c j -> b h c i', attn, v)
|
||||||
|
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
return out + residual
|
||||||
|
|
||||||
|
class VQGanVAE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
image_size,
|
||||||
|
channels = 3,
|
||||||
|
num_layers = 4,
|
||||||
|
layer_mults = None,
|
||||||
|
l2_recon_loss = False,
|
||||||
|
use_hinge_loss = True,
|
||||||
|
num_resnet_blocks = 1,
|
||||||
|
vgg = None,
|
||||||
|
vq_codebook_size = 512,
|
||||||
|
vq_decay = 0.8,
|
||||||
|
vq_commitment_weight = 1.,
|
||||||
|
vq_kmeans_init = True,
|
||||||
|
vq_use_cosine_sim = True,
|
||||||
|
use_attn = True,
|
||||||
|
attn_dim_head = 64,
|
||||||
|
attn_heads = 8,
|
||||||
|
resnet_groups = 16,
|
||||||
|
attn_dropout = 0.,
|
||||||
|
first_conv_kernel_size = 5,
|
||||||
|
use_vgg_and_gan = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||||
|
|
||||||
|
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||||
|
|
||||||
|
self.image_size = image_size
|
||||||
|
self.channels = channels
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.fmap_size = image_size // (num_layers ** 2)
|
||||||
|
self.codebook_size = vq_codebook_size
|
||||||
|
|
||||||
|
self.encoders = MList([])
|
||||||
|
self.decoders = MList([])
|
||||||
|
|
||||||
|
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
|
||||||
|
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
|
||||||
|
|
||||||
|
layer_dims = [dim * mult for mult in layer_mults]
|
||||||
|
dims = (dim, *layer_dims)
|
||||||
|
codebook_dim = layer_dims[-1]
|
||||||
|
|
||||||
|
dim_pairs = zip(dims[:-1], dims[1:])
|
||||||
|
|
||||||
|
append = lambda arr, t: arr.append(t)
|
||||||
|
prepend = lambda arr, t: arr.insert(0, t)
|
||||||
|
|
||||||
|
if not isinstance(num_resnet_blocks, tuple):
|
||||||
|
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
|
||||||
|
|
||||||
|
if not isinstance(use_attn, tuple):
|
||||||
|
use_attn = (*((False,) * (num_layers - 1)), use_attn)
|
||||||
|
|
||||||
|
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
|
||||||
|
assert len(use_attn) == num_layers
|
||||||
|
|
||||||
|
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||||
|
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||||
|
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||||
|
|
||||||
|
if layer_use_attn:
|
||||||
|
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||||
|
|
||||||
|
for _ in range(layer_num_resnet_blocks):
|
||||||
|
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||||
|
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||||
|
|
||||||
|
if layer_use_attn:
|
||||||
|
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||||
|
|
||||||
|
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||||
|
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||||
|
|
||||||
|
self.vq = VQ(
|
||||||
|
dim = codebook_dim,
|
||||||
|
codebook_size = vq_codebook_size,
|
||||||
|
decay = vq_decay,
|
||||||
|
commitment_weight = vq_commitment_weight,
|
||||||
|
accept_image_fmap = True,
|
||||||
|
kmeans_init = vq_kmeans_init,
|
||||||
|
use_cosine_sim = vq_use_cosine_sim,
|
||||||
|
**vq_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# reconstruction loss
|
||||||
|
|
||||||
|
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
|
||||||
|
|
||||||
|
# turn off GAN and perceptual loss if grayscale
|
||||||
|
|
||||||
|
self.vgg = None
|
||||||
|
self.discr = None
|
||||||
|
self.use_vgg_and_gan = use_vgg_and_gan
|
||||||
|
|
||||||
|
if not use_vgg_and_gan:
|
||||||
|
return
|
||||||
|
|
||||||
|
# preceptual loss
|
||||||
|
|
||||||
|
if exists(vgg):
|
||||||
|
self.vgg = vgg
|
||||||
|
else:
|
||||||
|
self.vgg = torchvision.models.vgg16(pretrained = True)
|
||||||
|
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
|
||||||
|
|
||||||
|
# gan related losses
|
||||||
|
|
||||||
|
self.discr = Discriminator(dims = dims, channels = channels)
|
||||||
|
|
||||||
|
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||||
|
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||||
|
|
||||||
|
def copy_for_eval(self):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
vae_copy = copy.deepcopy(self.cpu())
|
||||||
|
|
||||||
|
if vae_copy.use_vgg_and_gan:
|
||||||
|
del vae_copy.discr
|
||||||
|
del vae_copy.vgg
|
||||||
|
|
||||||
|
vae_copy.eval()
|
||||||
|
return vae_copy.to(device)
|
||||||
|
|
||||||
|
@remove_vgg
|
||||||
|
def state_dict(self, *args, **kwargs):
|
||||||
|
return super().state_dict(*args, **kwargs)
|
||||||
|
|
||||||
|
@remove_vgg
|
||||||
|
def load_state_dict(self, *args, **kwargs):
|
||||||
|
return super().load_state_dict(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def codebook(self):
|
||||||
|
return self.vq.codebook
|
||||||
|
|
||||||
|
def encode(self, fmap):
|
||||||
|
for enc in self.encoders:
|
||||||
|
fmap = enc(fmap)
|
||||||
|
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
def decode(self, fmap):
|
||||||
|
fmap = self.vq(fmap)
|
||||||
|
|
||||||
|
for dec in self.decoders:
|
||||||
|
fmap = dec(fmap)
|
||||||
|
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
img,
|
||||||
|
return_loss = False,
|
||||||
|
return_discr_loss = False,
|
||||||
|
return_recons = False
|
||||||
|
):
|
||||||
|
batch, channels, height, width, device = *img.shape, img.device
|
||||||
|
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||||
|
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||||
|
|
||||||
|
fmap, indices, commit_loss = self.encode(img)
|
||||||
|
|
||||||
|
fmap = self.decode(fmap)
|
||||||
|
|
||||||
|
if not return_loss and not return_discr_loss:
|
||||||
|
return fmap
|
||||||
|
|
||||||
|
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
|
||||||
|
|
||||||
|
# whether to return discriminator loss
|
||||||
|
|
||||||
|
if return_discr_loss:
|
||||||
|
assert exists(self.discr), 'discriminator must exist to train it'
|
||||||
|
|
||||||
|
fmap.detach_()
|
||||||
|
img.requires_grad_()
|
||||||
|
|
||||||
|
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||||
|
|
||||||
|
gp = gradient_penalty(img, img_discr_logits)
|
||||||
|
|
||||||
|
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||||
|
|
||||||
|
loss = discr_loss + gp
|
||||||
|
|
||||||
|
if return_recons:
|
||||||
|
return loss, fmap
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
# reconstruction loss
|
||||||
|
|
||||||
|
recon_loss = self.recon_loss_fn(fmap, img)
|
||||||
|
|
||||||
|
# early return if training on grayscale
|
||||||
|
|
||||||
|
if not self.use_vgg_and_gan:
|
||||||
|
if return_recons:
|
||||||
|
return recon_loss, fmap
|
||||||
|
|
||||||
|
return recon_loss
|
||||||
|
|
||||||
|
# perceptual loss
|
||||||
|
|
||||||
|
img_vgg_input = img
|
||||||
|
fmap_vgg_input = fmap
|
||||||
|
|
||||||
|
if img.shape[1] == 1:
|
||||||
|
# handle grayscale for vgg
|
||||||
|
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
|
||||||
|
|
||||||
|
img_vgg_feats = self.vgg(img_vgg_input)
|
||||||
|
recon_vgg_feats = self.vgg(fmap_vgg_input)
|
||||||
|
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
|
||||||
|
|
||||||
|
# generator loss
|
||||||
|
|
||||||
|
gen_loss = self.gen_loss(self.discr(fmap))
|
||||||
|
|
||||||
|
# calculate adaptive weight
|
||||||
|
|
||||||
|
last_dec_layer = self.decoders[-1].weight
|
||||||
|
|
||||||
|
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||||
|
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||||
|
|
||||||
|
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
|
||||||
|
adaptive_weight.clamp_(max = 1e4)
|
||||||
|
|
||||||
|
# combine losses
|
||||||
|
|
||||||
|
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
|
||||||
|
|
||||||
|
if return_recons:
|
||||||
|
return loss, fmap
|
||||||
|
|
||||||
|
return loss
|
||||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.0.26',
|
version = '0.0.32',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
@@ -30,6 +30,7 @@ setup(
|
|||||||
'torch>=1.10',
|
'torch>=1.10',
|
||||||
'torchvision',
|
'torchvision',
|
||||||
'tqdm',
|
'tqdm',
|
||||||
|
'vector-quantize-pytorch',
|
||||||
'x-clip>=0.4.4',
|
'x-clip>=0.4.4',
|
||||||
'youtokentome'
|
'youtokentome'
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user