mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 17:14:38 +01:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9549bd43b7 | ||
|
|
aee92dba4a | ||
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 | ||
|
|
d1f02e8f49 | ||
|
|
9faab59b23 |
14
README.md
14
README.md
@@ -732,8 +732,8 @@ clip = CLIP(
|
|||||||
|
|
||||||
# mock data
|
# mock data
|
||||||
|
|
||||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||||
images = torch.randn(4, 3, 256, 256).cuda()
|
images = torch.randn(32, 3, 256, 256).cuda()
|
||||||
|
|
||||||
# decoder (with unet)
|
# decoder (with unet)
|
||||||
|
|
||||||
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
loss = decoder_trainer(
|
||||||
loss.backward()
|
images,
|
||||||
|
text = text,
|
||||||
|
unet_number = unet_number, # which unet to train on
|
||||||
|
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||||
|
)
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss = diffusion_prior_trainer(text, images)
|
loss = diffusion_prior_trainer(text, images)
|
||||||
loss.backward()
|
|
||||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
# after much of the above three lines in a loop
|
# after much of the above three lines in a loop
|
||||||
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
|
|
||||||
## Citations
|
## Citations
|
||||||
|
|
||||||
|
|||||||
@@ -614,7 +614,6 @@ class Attention(nn.Module):
|
|||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
causal = False,
|
causal = False,
|
||||||
post_norm = False,
|
|
||||||
rotary_emb = None
|
rotary_emb = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -624,7 +623,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.norm = LayerNorm(dim)
|
self.norm = LayerNorm(dim)
|
||||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
@@ -635,7 +633,7 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
nn.Linear(inner_dim, dim, bias = False),
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
LayerNorm(dim) if post_norm else nn.Identity()
|
LayerNorm(dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None, attn_bias = None):
|
def forward(self, x, mask = None, attn_bias = None):
|
||||||
@@ -692,8 +690,7 @@ class Attention(nn.Module):
|
|||||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
out = self.to_out(out)
|
return self.to_out(out)
|
||||||
return self.post_norm(out)
|
|
||||||
|
|
||||||
class CausalTransformer(nn.Module):
|
class CausalTransformer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
|
|||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
|
||||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
@@ -1166,6 +1163,7 @@ class CrossAttention(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
dropout = 0.,
|
dropout = 0.,
|
||||||
|
norm_context = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = dim_head ** -0.5
|
self.scale = dim_head ** -0.5
|
||||||
@@ -1175,13 +1173,17 @@ class CrossAttention(nn.Module):
|
|||||||
context_dim = default(context_dim, dim)
|
context_dim = default(context_dim, dim)
|
||||||
|
|
||||||
self.norm = LayerNorm(dim)
|
self.norm = LayerNorm(dim)
|
||||||
self.norm_context = LayerNorm(context_dim)
|
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
|
||||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim, bias = False),
|
||||||
|
LayerNorm(dim)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, context, mask = None):
|
def forward(self, x, context, mask = None):
|
||||||
b, n, device = *x.shape[:2], x.device
|
b, n, device = *x.shape[:2], x.device
|
||||||
@@ -1377,6 +1379,9 @@ class Unet(nn.Module):
|
|||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
) if image_embed_dim != cond_dim else nn.Identity()
|
) if image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||||
|
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||||
|
|
||||||
# text encoding conditioning (optional)
|
# text encoding conditioning (optional)
|
||||||
|
|
||||||
self.text_to_cond = None
|
self.text_to_cond = None
|
||||||
@@ -1592,6 +1597,11 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||||
|
|
||||||
|
# normalize conditioning tokens
|
||||||
|
|
||||||
|
c = self.norm_cond(c)
|
||||||
|
mid_c = self.norm_mid_cond(mid_c)
|
||||||
|
|
||||||
# go through the layers of the unet, down and up
|
# go through the layers of the unet, down and up
|
||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|||||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
|||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
params,
|
params,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.999),
|
||||||
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False
|
filter_by_requires_grad = False
|
||||||
):
|
):
|
||||||
if filter_by_requires_grad:
|
if filter_by_requires_grad:
|
||||||
params = list(filter(lambda t: t.requires_grad, params))
|
params = list(filter(lambda t: t.requires_grad, params))
|
||||||
|
|
||||||
if wd == 0:
|
if wd == 0:
|
||||||
return Adam(params, lr = lr, betas = betas)
|
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||||
|
|
||||||
params = set(params)
|
params = set(params)
|
||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
|||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
from math import ceil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
return val if isinstance(val, tuple) else ((val,) * length)
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
@@ -40,6 +45,47 @@ def groupby_prefix_and_trim(prefix, d):
|
|||||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
return kwargs_without_prefix, kwargs
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
# gradient accumulation functions
|
||||||
|
|
||||||
|
def split_iterable(it, split_size):
|
||||||
|
accum = []
|
||||||
|
for ind in range(ceil(len(it) / split_size)):
|
||||||
|
start_index = ind * split_size
|
||||||
|
accum.append(it[start_index: (start_index + split_size)])
|
||||||
|
return accum
|
||||||
|
|
||||||
|
def split(t, split_size = None):
|
||||||
|
if not exists(split_size):
|
||||||
|
return t
|
||||||
|
|
||||||
|
if isinstance(t, torch.Tensor):
|
||||||
|
return t.split(split_size, dim = 0)
|
||||||
|
|
||||||
|
if isinstance(t, Iterable):
|
||||||
|
return split_iterable(t, split_size)
|
||||||
|
|
||||||
|
return TypeError
|
||||||
|
|
||||||
|
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||||
|
batch_size = len(x)
|
||||||
|
split_size = default(split_size, batch_size)
|
||||||
|
chunk_size = ceil(batch_size / split_size)
|
||||||
|
|
||||||
|
dict_len = len(kwargs)
|
||||||
|
dict_keys = kwargs.keys()
|
||||||
|
all_args = (x, *args, *kwargs.values())
|
||||||
|
len_all_args = len(all_args)
|
||||||
|
split_kwargs_index = len_all_args - dict_len
|
||||||
|
|
||||||
|
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||||
|
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||||
|
|
||||||
|
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||||
|
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
|
||||||
|
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||||
|
chunk_size_frac = chunk_size / batch_size
|
||||||
|
yield chunk_size_frac, (chunked_args, chunked_kwargs)
|
||||||
|
|
||||||
# print helpers
|
# print helpers
|
||||||
|
|
||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
@@ -90,7 +136,7 @@ class EMA(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.99,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
@@ -147,6 +193,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 3e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -173,6 +220,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
diffusion_prior.parameters(),
|
diffusion_prior.parameters(),
|
||||||
lr = lr,
|
lr = lr,
|
||||||
wd = wd,
|
wd = wd,
|
||||||
|
eps = eps,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -206,13 +254,22 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
x,
|
||||||
*args,
|
*args,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_loss = 0.
|
||||||
loss = self.diffusion_prior(*args, **kwargs)
|
|
||||||
return self.scaler.scale(loss / divisor)
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||||
|
|
||||||
|
loss = loss * chunk_size_frac
|
||||||
|
total_loss += loss.item()
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|
||||||
# decoder trainer
|
# decoder trainer
|
||||||
|
|
||||||
@@ -221,8 +278,9 @@ class DecoderTrainer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
decoder,
|
decoder,
|
||||||
use_ema = True,
|
use_ema = True,
|
||||||
lr = 3e-4,
|
lr = 2e-5,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
|
eps = 1e-8,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -247,13 +305,14 @@ class DecoderTrainer(nn.Module):
|
|||||||
# be able to finely customize learning rate, weight decay
|
# be able to finely customize learning rate, weight decay
|
||||||
# per unet
|
# per unet
|
||||||
|
|
||||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||||
|
|
||||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||||
optimizer = get_optimizer(
|
optimizer = get_optimizer(
|
||||||
unet.parameters(),
|
unet.parameters(),
|
||||||
lr = unet_lr,
|
lr = unet_lr,
|
||||||
wd = unet_wd,
|
wd = unet_wd,
|
||||||
|
eps = unet_eps,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -321,9 +380,17 @@ class DecoderTrainer(nn.Module):
|
|||||||
x,
|
x,
|
||||||
*,
|
*,
|
||||||
unet_number,
|
unet_number,
|
||||||
divisor = 1,
|
max_batch_size = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
total_loss = 0.
|
||||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
|
||||||
return self.scale(loss / divisor, unet_number = unet_number)
|
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||||
|
|
||||||
|
loss = loss * chunk_size_frac
|
||||||
|
total_loss += loss.item()
|
||||||
|
self.scale(loss, unet_number = unet_number).backward()
|
||||||
|
|
||||||
|
return total_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user