mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 03:24:22 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b9bbec7d1 | ||
|
|
1bb9fc9829 | ||
|
|
5e421bd5bb | ||
|
|
67fcab1122 | ||
|
|
5bfbccda22 | ||
|
|
989275ff59 | ||
|
|
56408f4a40 | ||
|
|
d1a697ac23 | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb | ||
|
|
a2ef69af66 |
15
README.md
15
README.md
@@ -760,7 +760,7 @@ decoder = Decoder(
|
||||
unet = (unet1, unet2),
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 1,
|
||||
timesteps = 1000,
|
||||
condition_on_text_encodings = True
|
||||
).cuda()
|
||||
|
||||
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
|
||||
loss.backward()
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
# after much training
|
||||
# you can sample from the exponentially moving averaged unets as so
|
||||
|
||||
mock_image_embed = torch.randn(4, 512).cuda()
|
||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||
```
|
||||
|
||||
## CLI (wip)
|
||||
@@ -811,15 +817,16 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [x] bring in tools to train vqgan-vae
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -922,6 +922,7 @@ class ConvNextBlock(nn.Module):
|
||||
dim_out,
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
mult = 2,
|
||||
norm = True
|
||||
):
|
||||
@@ -940,6 +941,14 @@ class ConvNextBlock(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
self.time_mlp = None
|
||||
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.GELU(),
|
||||
nn.Linear(time_cond_dim, dim)
|
||||
)
|
||||
|
||||
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
@@ -952,9 +961,13 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None):
|
||||
def forward(self, x, cond = None, time = None):
|
||||
h = self.ds_conv(x)
|
||||
|
||||
if exists(time) and exists(self.time_mlp):
|
||||
t = self.time_mlp(time)
|
||||
h = rearrange(t, 'b c -> b c 1 1') + h
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
@@ -1059,6 +1072,8 @@ class Unet(nn.Module):
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -1076,22 +1091,34 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
|
||||
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
# time, image embeddings, and optional text encoding
|
||||
|
||||
cond_dim = default(cond_dim, dim)
|
||||
time_cond_dim = dim * 4
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
self.to_time_hiddens = nn.Sequential(
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
||||
nn.Linear(dim, time_cond_dim),
|
||||
nn.GELU()
|
||||
)
|
||||
|
||||
self.to_time_tokens = nn.Sequential(
|
||||
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
|
||||
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||
)
|
||||
|
||||
self.to_time_cond = nn.Sequential(
|
||||
nn.Linear(time_cond_dim, time_cond_dim)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
@@ -1133,26 +1160,26 @@ class Unet(nn.Module):
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
@@ -1214,9 +1241,16 @@ class Unet(nn.Module):
|
||||
if exists(lowres_cond_img):
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# initial convolution
|
||||
|
||||
x = self.init_conv(x)
|
||||
|
||||
# time conditioning
|
||||
|
||||
time_tokens = self.time_mlp(time)
|
||||
time_hiddens = self.to_time_hiddens(time)
|
||||
|
||||
time_tokens = self.to_time_tokens(time_hiddens)
|
||||
t = self.to_time_cond(time_hiddens)
|
||||
|
||||
# conditional dropout
|
||||
|
||||
@@ -1283,24 +1317,24 @@ class Unet(nn.Module):
|
||||
hiddens = []
|
||||
|
||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||
x = convnext(x, c)
|
||||
x = convnext(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
x = convnext2(x, c, t)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, mid_c)
|
||||
x = self.mid_block1(x, mid_c, t)
|
||||
|
||||
if exists(self.mid_attn):
|
||||
x = self.mid_attn(x)
|
||||
|
||||
x = self.mid_block2(x, mid_c)
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, c)
|
||||
x = convnext(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
x = convnext2(x, c, t)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
@@ -1540,7 +1574,13 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
def sample(
|
||||
self,
|
||||
image_embed,
|
||||
text = None,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
):
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = text_mask = None
|
||||
@@ -1552,7 +1592,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||
|
||||
@@ -1584,6 +1624,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||
break
|
||||
|
||||
return img
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -3,6 +3,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.ema_unets = nn.ModuleList([])
|
||||
|
||||
self.amp = amp
|
||||
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
@@ -133,10 +137,23 @@ class DecoderTrainer(nn.Module):
|
||||
if self.use_ema:
|
||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||
|
||||
scaler = GradScaler(enabled = amp)
|
||||
setattr(self, f'scaler{ind}', scaler)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
def scale(self, loss, *, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
return scaler.scale(loss)
|
||||
|
||||
def update(self, unet_number):
|
||||
assert 1 <= unet_number <= self.num_unets
|
||||
index = unet_number - 1
|
||||
@@ -146,12 +163,36 @@ class DecoderTrainer(nn.Module):
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
optimizer.step()
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if self.use_ema:
|
||||
ema_unet = self.ema_unets[index]
|
||||
ema_unet.update()
|
||||
|
||||
def forward(self, x, *, unet_number, **kwargs):
|
||||
return self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
@torch.no_grad()
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
trainable_unets = self.decoder.unets
|
||||
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||
|
||||
output = self.decoder.sample(*args, **kwargs)
|
||||
|
||||
if self.use_ema:
|
||||
self.decoder.unets = trainable_unets # restore original training unets
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
|
||||
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from math import sqrt
|
||||
import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
def cast_tuple(t):
|
||||
return t if isinstance(t, (tuple, list)) else (t,)
|
||||
|
||||
def yes_or_no(question):
|
||||
answer = input(f'{question} (y/n) ')
|
||||
return answer.lower() in ('yes', 'y')
|
||||
|
||||
def accum_log(log, new_logs):
|
||||
for key, new_value in new_logs.items():
|
||||
old_value = log.get(key, 0.)
|
||||
log[key] = old_value + new_value
|
||||
return log
|
||||
|
||||
# classes
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
print(f'{len(self.paths)} training samples found at {folder}')
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize(image_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.CenterCrop(image_size),
|
||||
T.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
# main trainer class
|
||||
|
||||
class VQGanVAETrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
*,
|
||||
num_train_steps,
|
||||
lr,
|
||||
batch_size,
|
||||
folder,
|
||||
grad_accum_every,
|
||||
wd = 0.,
|
||||
save_results_every = 100,
|
||||
save_model_every = 1000,
|
||||
results_folder = './results',
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
image_size = vae.image_size
|
||||
|
||||
self.vae = vae
|
||||
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
|
||||
|
||||
self.register_buffer('steps', torch.Tensor([0]))
|
||||
|
||||
self.num_train_steps = num_train_steps
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum_every = grad_accum_every
|
||||
|
||||
all_parameters = set(vae.parameters())
|
||||
discr_parameters = set(vae.discr.parameters())
|
||||
vae_parameters = all_parameters - discr_parameters
|
||||
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
|
||||
# split for validation
|
||||
|
||||
if valid_frac > 0:
|
||||
train_size = int((1 - valid_frac) * len(self.ds))
|
||||
valid_size = len(self.ds) - train_size
|
||||
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
|
||||
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
|
||||
else:
|
||||
self.valid_ds = self.ds
|
||||
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
|
||||
|
||||
# dataloader
|
||||
|
||||
self.dl = cycle(DataLoader(
|
||||
self.ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.valid_dl = cycle(DataLoader(
|
||||
self.valid_ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.save_model_every = save_model_every
|
||||
self.save_results_every = save_results_every
|
||||
|
||||
self.apply_grad_penalty_every = apply_grad_penalty_every
|
||||
|
||||
self.results_folder = Path(results_folder)
|
||||
|
||||
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
|
||||
rmtree(str(self.results_folder))
|
||||
|
||||
self.results_folder.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def train_step(self):
|
||||
device = next(self.vae.parameters()).device
|
||||
steps = int(self.steps.item())
|
||||
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
|
||||
|
||||
self.vae.train()
|
||||
|
||||
# logs
|
||||
|
||||
logs = {}
|
||||
|
||||
# update vae (generator)
|
||||
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
discr_loss = 0
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
|
||||
|
||||
# update exponential moving averaged generator
|
||||
|
||||
self.ema_vae.update()
|
||||
|
||||
# sample results every so often
|
||||
|
||||
if not (steps % self.save_results_every):
|
||||
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
|
||||
model.eval()
|
||||
|
||||
imgs = next(self.dl)
|
||||
imgs = imgs.to(device)
|
||||
|
||||
recons = model(imgs)
|
||||
nrows = int(sqrt(self.batch_size))
|
||||
|
||||
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
|
||||
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
|
||||
|
||||
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
|
||||
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
|
||||
|
||||
logs['reconstructions'] = grid
|
||||
|
||||
save_image(grid, str(self.results_folder / f'{filename}.png'))
|
||||
|
||||
print(f'{steps}: saving to {str(self.results_folder)}')
|
||||
|
||||
# save model every so often
|
||||
|
||||
if not (steps % self.save_model_every):
|
||||
state_dict = self.vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.pt')
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
ema_state_dict = self.ema_vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
|
||||
torch.save(ema_state_dict, model_path)
|
||||
|
||||
print(f'{steps}: saving model to {str(self.results_folder)}')
|
||||
|
||||
self.steps += 1
|
||||
return logs
|
||||
|
||||
def train(self, log_fn = noop):
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
while self.steps < self.num_train_steps:
|
||||
logs = self.train_step()
|
||||
log_fn(logs)
|
||||
|
||||
print('training complete')
|
||||
@@ -327,6 +327,108 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# convnext enc dec
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ConvNext(nn.Module):
|
||||
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ConvNextEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
num_blocks = 1,
|
||||
first_conv_kernel_size = 5,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_blocks, tuple):
|
||||
num_blocks = (*((0,) * (layers - 1)), num_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_blocks):
|
||||
append(self.encoders, ConvNext(dim_out))
|
||||
prepend(self.decoders, ConvNext(dim_out))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
for dec in self.decoders:
|
||||
x = dec(x)
|
||||
return x
|
||||
|
||||
# vqgan attention layer
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
@@ -568,6 +670,8 @@ class VQGanVAE(nn.Module):
|
||||
enc_dec_klass = ResnetEncDec
|
||||
elif vae_type == 'vit':
|
||||
enc_dec_klass = ViTEncDec
|
||||
elif vae_type == 'convnext':
|
||||
enc_dec_klass = ConvNextEncDec
|
||||
else:
|
||||
raise ValueError(f'{vae_type} not valid')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user