Compare commits

...

10 Commits

5 changed files with 122 additions and 33 deletions

View File

@@ -732,8 +732,8 @@ clip = CLIP(
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()
# decoder (with unet)
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
)
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
loss = decoder_trainer(
images,
text = text,
unet_number = unet_number, # which unet to train on
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
)
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well
## Citations

View File

@@ -1,7 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from functools import partial, wraps
from contextlib import contextmanager
from collections import namedtuple
from pathlib import Path
@@ -45,6 +45,14 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
@@ -606,7 +614,6 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False,
rotary_emb = None
):
super().__init__()
@@ -616,7 +623,6 @@ class Attention(nn.Module):
self.causal = causal
self.norm = LayerNorm(dim)
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -627,7 +633,7 @@ class Attention(nn.Module):
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
LayerNorm(dim)
)
def forward(self, x, mask = None, attn_bias = None):
@@ -684,8 +690,7 @@ class Attention(nn.Module):
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.post_norm(out)
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
@@ -711,7 +716,7 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -1158,6 +1163,7 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -1167,13 +1173,17 @@ class CrossAttention(nn.Module):
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim)
)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
@@ -1369,6 +1379,9 @@ class Unet(nn.Module):
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
@@ -1584,6 +1597,11 @@ class Unet(nn.Module):
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# go through the layers of the unet, down and up
hiddens = []
@@ -1844,6 +1862,8 @@ class Decoder(BaseGaussianDiffusion):
b = shape[0]
img = torch.randn(shape, device = device)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(
unet,
@@ -1868,9 +1888,7 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1]
x_start = normalize_neg_one_to_one(x_start)
if exists(lowres_cond_img):
lowres_cond_img = normalize_neg_one_to_one(lowres_cond_img)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
# get x_t

View File

@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
def get_optimizer(
params,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
betas = (0.9, 0.999),
eps = 1e-8,
filter_by_requires_grad = False
):
if filter_by_requires_grad:
params = list(filter(lambda t: t.requires_grad, params))
if wd == 0:
return Adam(params, lr = lr, betas = betas)
return Adam(params, lr = lr, betas = betas, eps = eps)
params = set(params)
wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
{'params': list(no_wd_params), 'weight_decay': 0},
]
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -1,6 +1,8 @@
import time
import copy
from math import ceil
from functools import partial
from collections.abc import Iterable
import torch
from torch import nn
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
@@ -40,6 +45,47 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# gradient accumulation functions
def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum
def split(t, split_size = None):
if not exists(split_size):
return t
if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)
if isinstance(t, Iterable):
return split_iterable(t, split_size)
return TypeError
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
batch_size = len(x)
split_size = default(split_size, batch_size)
chunk_size = ceil(batch_size / split_size)
dict_len = len(kwargs)
dict_keys = kwargs.keys()
all_args = (x, *args, *kwargs.values())
len_all_args = len(all_args)
split_kwargs_index = len_all_args - dict_len
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
# print helpers
def print_ribbon(s, symbol = '=', repeat = 40):
@@ -90,7 +136,7 @@ class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.99,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
@@ -147,6 +193,7 @@ class DiffusionPriorTrainer(nn.Module):
use_ema = True,
lr = 3e-4,
wd = 1e-2,
eps = 1e-6,
max_grad_norm = None,
amp = False,
**kwargs
@@ -173,6 +220,7 @@ class DiffusionPriorTrainer(nn.Module):
diffusion_prior.parameters(),
lr = lr,
wd = wd,
eps = eps,
**kwargs
)
@@ -206,13 +254,21 @@ class DiffusionPriorTrainer(nn.Module):
def forward(
self,
x,
*args,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
total_loss += loss.item() * chunk_size_frac
self.scaler.scale(loss * chunk_size_frac).backward()
return total_loss
# decoder trainer
@@ -221,8 +277,9 @@ class DecoderTrainer(nn.Module):
self,
decoder,
use_ema = True,
lr = 3e-4,
lr = 2e-5,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = None,
amp = False,
**kwargs
@@ -247,13 +304,14 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
eps = unet_eps,
**kwargs
)
@@ -321,9 +379,17 @@ class DecoderTrainer(nn.Module):
x,
*,
unet_number,
divisor = 1,
max_batch_size = None,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
total_loss += loss.item() * chunk_size_frac
self.scale(loss * chunk_size_frac, unet_number = unet_number).backward()
return total_loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.16',
version = '0.2.27',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',