mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 13:54:21 +01:00
Compare commits
9 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 8b9bbec7d1 |  |
|  | 1bb9fc9829 |  |
|  | 5e421bd5bb |  |
|  | 67fcab1122 |  |
|  | 5bfbccda22 |  |
|  | 989275ff59 |  |
|  | 56408f4a40 |  |
|  | d1a697ac23 |  |
|  | ebe01749ed |  |
README.md (13 lines changed)
@@ -760,7 +760,7 @@ decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 1,
+    timesteps = 1000,
     condition_on_text_encodings = True
 ).cuda()
 
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
     loss.backward()
 
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
+
+# after much training
+# you can sample from the exponentially moving averaged unets as so
+
+mock_image_embed = torch.randn(4, 512).cuda()
+images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
 ## CLI (wip)
@@ -812,14 +818,15 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [x] take care of mixed precision as well as gradient accumulation within decoder trainer
+- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
+- [x] bring in tools to train vqgan-vae
+- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
-- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in tools to train vqgan-vae
 
 ## Citations
 
dalle2_pytorch/dalle2_pytorch.py
@@ -922,6 +922,7 @@ class ConvNextBlock(nn.Module):
         dim_out,
         *,
         cond_dim = None,
+        time_cond_dim = None,
         mult = 2,
         norm = True
     ):
@@ -940,6 +941,14 @@ class ConvNextBlock(nn.Module):
             )
         )
 
+        self.time_mlp = None
+
+        if exists(time_cond_dim):
+            self.time_mlp = nn.Sequential(
+                nn.GELU(),
+                nn.Linear(time_cond_dim, dim)
+            )
+
         self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
 
         inner_dim = int(dim_out * mult)
@@ -952,9 +961,13 @@ class ConvNextBlock(nn.Module):
 
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
 
-    def forward(self, x, cond = None):
+    def forward(self, x, cond = None, time = None):
         h = self.ds_conv(x)
 
+        if exists(time) and exists(self.time_mlp):
+            t = self.time_mlp(time)
+            h = rearrange(t, 'b c -> b c 1 1') + h
+
         if exists(self.cross_attn):
             assert exists(cond)
             h = self.cross_attn(h, context = cond) + h
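For orientation, a small standalone sketch of the time-conditioning pattern these ConvNextBlock hunks add: the time embedding is projected to the block's channel count and broadcast-added to the feature map. Class and variable names here are illustrative, not the repository's.

```python
import torch
from torch import nn
from einops import rearrange

class TimeConditionedBlock(nn.Module):
    # illustrative block: GELU + Linear project the time embedding to the channel
    # dimension, then it is added per-channel to the feature map (FiLM-style bias)
    def __init__(self, dim, time_cond_dim):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.GELU(), nn.Linear(time_cond_dim, dim))

    def forward(self, h, time_emb):
        t = self.time_mlp(time_emb)                  # (b, dim)
        return h + rearrange(t, 'b c -> b c 1 1')    # broadcast over height and width

h = torch.randn(2, 64, 32, 32)     # feature map
time_emb = torch.randn(2, 256)     # time embedding of size time_cond_dim
out = TimeConditionedBlock(64, 256)(h, time_emb)  # (2, 64, 32, 32)
```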
@@ -1059,6 +1072,8 @@ class Unet(nn.Module):
         cond_on_text_encodings = False,
         max_text_len = 256,
         cond_on_image_embeds = False,
+        init_dim = None,
+        init_conv_kernel_size = 7
     ):
         super().__init__()
         # save locals to take care of some hyperparameters for cascading DDPM
@@ -1076,22 +1091,34 @@ class Unet(nn.Module):
         self.channels = channels
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
 
-        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
+        init_dim = default(init_dim, dim // 2)
+
+        assert (init_conv_kernel_size % 2) == 1
+        self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+
+        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         # time, image embeddings, and optional text encoding
 
         cond_dim = default(cond_dim, dim)
+        time_cond_dim = dim * 4
 
-        self.time_mlp = nn.Sequential(
+        self.to_time_hiddens = nn.Sequential(
             SinusoidalPosEmb(dim),
-            nn.Linear(dim, dim * 4),
-            nn.GELU(),
-            nn.Linear(dim * 4, cond_dim * num_time_tokens),
+            nn.Linear(dim, time_cond_dim),
+            nn.GELU()
+        )
+
+        self.to_time_tokens = nn.Sequential(
+            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
             Rearrange('b (r d) -> b r d', r = num_time_tokens)
         )
 
+        self.to_time_cond = nn.Sequential(
+            nn.Linear(time_cond_dim, time_cond_dim)
+        )
+
         self.image_to_cond = nn.Sequential(
             nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
             Rearrange('b (n d) -> b n d', n = num_image_tokens)
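A shape-level sketch of the refactor above, with illustrative sizes: the time embedding is mapped to hidden features, which then branch into cross-attention conditioning tokens and a per-block conditioning vector. `SinusoidalPosEmb` is replaced by a random tensor here so the snippet stands alone; the dimensions are not the repository's defaults.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, cond_dim, num_time_tokens = 128, 128, 2   # illustrative sizes
time_cond_dim = dim * 4

to_time_hiddens = nn.Sequential(nn.Linear(dim, time_cond_dim), nn.GELU())
to_time_tokens  = nn.Sequential(
    nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
    Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
to_time_cond    = nn.Linear(time_cond_dim, time_cond_dim)

time_emb = torch.randn(4, dim)          # stand-in for the SinusoidalPosEmb output
hiddens  = to_time_hiddens(time_emb)    # (4, time_cond_dim)
tokens   = to_time_tokens(hiddens)      # (4, num_time_tokens, cond_dim) -> joins the conditioning tokens
t        = to_time_cond(hiddens)        # (4, time_cond_dim) -> added inside each ConvNextBlock
```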
@@ -1133,26 +1160,26 @@ class Unet(nn.Module):
             layer_cond_dim = cond_dim if not is_first else None
 
             self.downs.append(nn.ModuleList([
-                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
                 Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
 
         mid_dim = dims[-1]
 
-        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
         self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
-        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
+        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 2)
             layer_cond_dim = cond_dim if not is_last else None
 
             self.ups.append(nn.ModuleList([
-                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
-                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
+                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
                 Upsample(dim_in)
             ]))
 
@@ -1214,9 +1241,16 @@ class Unet(nn.Module):
         if exists(lowres_cond_img):
             x = torch.cat((x, lowres_cond_img), dim = 1)
 
+        # initial convolution
+
+        x = self.init_conv(x)
+
         # time conditioning
 
-        time_tokens = self.time_mlp(time)
+        time_hiddens = self.to_time_hiddens(time)
+
+        time_tokens = self.to_time_tokens(time_hiddens)
+        t = self.to_time_cond(time_hiddens)
 
         # conditional dropout
 
@@ -1283,24 +1317,24 @@ class Unet(nn.Module):
         hiddens = []
 
         for convnext, sparse_attn, convnext2, downsample in self.downs:
-            x = convnext(x, c)
+            x = convnext(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c)
+            x = convnext2(x, c, t)
             hiddens.append(x)
             x = downsample(x)
 
-        x = self.mid_block1(x, mid_c)
+        x = self.mid_block1(x, mid_c, t)
 
         if exists(self.mid_attn):
             x = self.mid_attn(x)
 
-        x = self.mid_block2(x, mid_c)
+        x = self.mid_block2(x, mid_c, t)
 
         for convnext, sparse_attn, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = convnext(x, c)
+            x = convnext(x, c, t)
             x = sparse_attn(x)
-            x = convnext2(x, c)
+            x = convnext2(x, c, t)
             x = upsample(x)
 
         return self.final_conv(x)
@@ -1540,7 +1574,13 @@ class Decoder(BaseGaussianDiffusion):
 
     @torch.no_grad()
     @eval_decorator
-    def sample(self, image_embed, text = None, cond_scale = 1.):
+    def sample(
+        self,
+        image_embed,
+        text = None,
+        cond_scale = 1.,
+        stop_at_unet_number = None
+    ):
         batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
@@ -1552,7 +1592,7 @@ class Decoder(BaseGaussianDiffusion):
 
         img = None
 
-        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
 
             context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
 
@@ -1584,6 +1624,9 @@ class Decoder(BaseGaussianDiffusion):
 
             img = vae.decode(img)
 
+            if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
+                break
+
         return img
 
     def forward(
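The new `stop_at_unet_number` argument lets sampling halt partway through the cascade. A minimal usage sketch, assuming the `decoder`, `text` and CUDA setup from the README example above (output shapes follow the README's `image_sizes = (128, 256)`):

```python
import torch

# `decoder` and `text` as constructed in the README example above
image_embed = torch.randn(4, 512).cuda()

full_res = decoder.sample(image_embed, text = text)                          # both unets -> (4, 3, 256, 256)
base_res = decoder.sample(image_embed, text = text, stop_at_unet_number = 1) # stop after unet 1 -> (4, 3, 128, 128)
```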
dalle2_pytorch/train.py

@@ -144,6 +144,10 @@ class DecoderTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm
 
+    @property
+    def unets(self):
+        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
+
     def scale(self, loss, *, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1
@@ -169,6 +173,18 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()
 
+    @torch.no_grad()
+    def sample(self, *args, **kwargs):
+        if self.use_ema:
+            trainable_unets = self.decoder.unets
+            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+
+        output = self.decoder.sample(*args, **kwargs)
+
+        if self.use_ema:
+            self.decoder.unets = trainable_unets # restore original training unets
+        return output
+
     def forward(
         self,
         x,
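A corresponding sketch at the trainer level, mirroring the README snippet above: when `use_ema` is enabled, `DecoderTrainer.sample` temporarily swaps in the EMA unets, which are also reachable through the new `unets` property. Names reuse the README example; nothing beyond what the hunks above show is assumed.

```python
import torch

# `decoder_trainer` and `text` as in the README training loop above
mock_image_embed = torch.randn(4, 512).cuda()

images = decoder_trainer.sample(mock_image_embed, text = text)  # sampled with the EMA unets when use_ema = True
ema_unets = decoder_trainer.unets                                # nn.ModuleList of the EMA unet copies
```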
dalle2_pytorch/train_vqgan_vae.py (new file, 266 lines)
@@ -0,0 +1,266 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree

import torch
from torch import nn

from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid, save_image

from einops import rearrange

from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer

# helpers

def exists(val):
    return val is not None

def noop(*args, **kwargs):
    pass

def cycle(dl):
    while True:
        for data in dl:
            yield data

def cast_tuple(t):
    return t if isinstance(t, (tuple, list)) else (t,)

def yes_or_no(question):
    answer = input(f'{question} (y/n) ')
    return answer.lower() in ('yes', 'y')

def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        old_value = log.get(key, 0.)
        log[key] = old_value + new_value
    return log

# classes

class ImageDataset(Dataset):
    def __init__(
        self,
        folder,
        image_size,
        exts = ['jpg', 'jpeg', 'png']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        print(f'{len(self.paths)} training samples found at {folder}')

        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

# main trainer class

class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
        vae,
        *,
        num_train_steps,
        lr,
        batch_size,
        folder,
        grad_accum_every,
        wd = 0.,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 2000,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
    ):
        super().__init__()
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        image_size = vae.image_size

        self.vae = vae
        self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        all_parameters = set(vae.parameters())
        discr_parameters = set(vae.discr.parameters())
        vae_parameters = all_parameters - discr_parameters

        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

        # create dataset

        self.ds = ImageDataset(folder, image_size = image_size)

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloader

        self.dl = cycle(DataLoader(
            self.ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.valid_dl = cycle(DataLoader(
            self.valid_ds,
            batch_size = batch_size,
            shuffle = True
        ))

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def train_step(self):
        device = next(self.vae.parameters()).device
        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        self.vae.train()

        # logs

        logs = {}

        # update vae (generator)

        for _ in range(self.grad_accum_every):
            img = next(self.dl)
            img = img.to(device)

            loss = self.vae(
                img,
                return_loss = True,
                apply_grad_penalty = apply_grad_penalty
            )

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

            (loss / self.grad_accum_every).backward()

        self.optim.step()
        self.optim.zero_grad()

        # update discriminator

        if exists(self.vae.discr):
            discr_loss = 0
            for _ in range(self.grad_accum_every):
                img = next(self.dl)
                img = img.to(device)

                loss = self.vae(img, return_discr_loss = True)
                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

                (loss / self.grad_accum_every).backward()

            self.discr_optim.step()
            self.discr_optim.zero_grad()

            # log

            print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update exponential moving averaged generator

        self.ema_vae.update()

        # sample results every so often

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # save model every so often

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        device = next(self.vae.parameters()).device

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        print('training complete')
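A minimal usage sketch for the new trainer. The `VQGanVAE(...)` keyword arguments below (`dim`, `image_size`) are placeholders for illustration; only the trainer keywords shown are confirmed by the file above.

```python
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.train_vqgan_vae import VQGanVAETrainer

vae = VQGanVAE(dim = 32, image_size = 256)  # constructor arguments are assumptions, adjust to the actual class

trainer = VQGanVAETrainer(
    vae,
    folder = './path/to/images',   # placeholder path to a folder of training images
    num_train_steps = 100000,
    lr = 3e-4,
    batch_size = 4,
    grad_accum_every = 8
)

trainer.train()  # repeatedly calls train_step(), printing losses and saving results / checkpoints
```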
dalle2_pytorch/vqgan_vae.py

@@ -327,6 +327,108 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x
 
+# convnext enc dec
+
+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+    def forward(self, x):
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) / (var + self.eps).sqrt() * self.g
+
+class ConvNext(nn.Module):
+    def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
+            ChanLayerNorm(dim),
+            nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
+            nn.GELU(),
+            nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
+        )
+
+    def forward(self, x):
+        return self.net(x) + x
+
+class ConvNextEncDec(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        channels = 3,
+        layers = 4,
+        layer_mults = None,
+        num_blocks = 1,
+        first_conv_kernel_size = 5,
+        use_attn = True,
+        attn_dim_head = 64,
+        attn_heads = 8,
+        attn_dropout = 0.,
+    ):
+        super().__init__()
+
+        self.layers = layers
+
+        self.encoders = MList([])
+        self.decoders = MList([])
+
+        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
+
+        layer_dims = [dim * mult for mult in layer_mults]
+        dims = (dim, *layer_dims)
+
+        self.encoded_dim = dims[-1]
+
+        dim_pairs = zip(dims[:-1], dims[1:])
+
+        append = lambda arr, t: arr.append(t)
+        prepend = lambda arr, t: arr.insert(0, t)
+
+        if not isinstance(num_blocks, tuple):
+            num_blocks = (*((0,) * (layers - 1)), num_blocks)
+
+        if not isinstance(use_attn, tuple):
+            use_attn = (*((False,) * (layers - 1)), use_attn)
+
+        assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
+        assert len(use_attn) == layers
+
+        for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
+            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
+
+            if layer_use_attn:
+                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+            for _ in range(layer_num_blocks):
+                append(self.encoders, ConvNext(dim_out))
+                prepend(self.decoders, ConvNext(dim_out))
+
+            if layer_use_attn:
+                append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+        prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
+        append(self.decoders, nn.Conv2d(dim, channels, 1))
+
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
+    def encode(self, x):
+        for enc in self.encoders:
+            x = enc(x)
+        return x
+
+    def decode(self, x):
+        for dec in self.decoders:
+            x = dec(x)
+        return x
+
 # vqgan attention layer
 
 class VQGanAttention(nn.Module):
@@ -568,6 +670,8 @@ class VQGanVAE(nn.Module):
             enc_dec_klass = ResnetEncDec
         elif vae_type == 'vit':
             enc_dec_klass = ViTEncDec
+        elif vae_type == 'convnext':
+            enc_dec_klass = ConvNextEncDec
         else:
             raise ValueError(f'{vae_type} not valid')
 
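Finally, a hedged sketch of what the `vae_type` hunk enables: selecting the new ConvNext encoder/decoder backbone when constructing the VAE. Only the `'convnext'` branch is confirmed by the diff; the other keyword names are assumptions.

```python
from dalle2_pytorch.vqgan_vae import VQGanVAE

# 'convnext' now maps to ConvNextEncDec; dim / image_size are illustrative placeholders
vae = VQGanVAE(
    dim = 32,
    image_size = 256,
    vae_type = 'convnext'
)
```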