mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0cd5f24b6 | ||
|
|
b494ed81d4 | ||
|
|
ff3474f05c | ||
|
|
d5293f19f1 | ||
|
|
e697183849 | ||
|
|
591d37e266 |
14
README.md
14
README.md
@@ -732,8 +732,8 @@ clip = CLIP(
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
text = torch.randint(0, 49408, (32, 256)).cuda()
|
||||
images = torch.randn(32, 3, 256, 256).cuda()
|
||||
|
||||
# decoder (with unet)
|
||||
|
||||
@@ -774,8 +774,12 @@ decoder_trainer = DecoderTrainer(
|
||||
)
|
||||
|
||||
for unet_number in (1, 2):
|
||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||
loss.backward()
|
||||
loss = decoder_trainer(
|
||||
images,
|
||||
text = text,
|
||||
unet_number = unet_number, # which unet to train on
|
||||
max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
|
||||
)
|
||||
|
||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||
|
||||
@@ -839,7 +843,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||
)
|
||||
|
||||
loss = diffusion_prior_trainer(text, images)
|
||||
loss.backward()
|
||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||
|
||||
# after much of the above three lines in a loop
|
||||
@@ -1017,6 +1020,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -1163,6 +1163,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
norm_context = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -1172,7 +1173,7 @@ class CrossAttention(nn.Module):
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -1378,6 +1379,9 @@ class Unet(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.norm_cond = nn.LayerNorm(cond_dim)
|
||||
self.norm_mid_cond = nn.LayerNorm(cond_dim)
|
||||
|
||||
# text encoding conditioning (optional)
|
||||
|
||||
self.text_to_cond = None
|
||||
@@ -1593,6 +1597,11 @@ class Unet(nn.Module):
|
||||
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# normalize conditioning tokens
|
||||
|
||||
c = self.norm_cond(c)
|
||||
mid_c = self.norm_mid_cond(mid_c)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
@@ -7,16 +7,17 @@ def separate_weight_decayable_params(params):
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
betas = (0.9, 0.999),
|
||||
eps = 1e-8,
|
||||
filter_by_requires_grad = False
|
||||
):
|
||||
if filter_by_requires_grad:
|
||||
params = list(filter(lambda t: t.requires_grad, params))
|
||||
|
||||
if wd == 0:
|
||||
return Adam(params, lr = lr, betas = betas)
|
||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||
|
||||
params = set(params)
|
||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||
@@ -26,4 +27,4 @@ def get_optimizer(
|
||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||
]
|
||||
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
|
||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
import copy
|
||||
from math import ceil
|
||||
from functools import partial
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -14,6 +16,9 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
@@ -40,6 +45,46 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# gradient accumulation functions
|
||||
|
||||
def split_iterable(it, split_size):
|
||||
accum = []
|
||||
for ind in range(ceil(len(it) / split_size)):
|
||||
start_index = ind * split_size
|
||||
accum.append(it[start_index: (start_index + split_size)])
|
||||
return accum
|
||||
|
||||
def split(t, split_size = None):
|
||||
if not exists(split_size):
|
||||
return t
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
return t.split(split_size, dim = 0)
|
||||
|
||||
if isinstance(t, Iterable):
|
||||
return split_iterable(t, split_size)
|
||||
|
||||
return TypeError
|
||||
|
||||
def split_args_and_kwargs(x, *args, split_size = None, **kwargs):
|
||||
batch_size = len(x)
|
||||
split_size = default(split_size, batch_size)
|
||||
chunk_size = ceil(batch_size / split_size)
|
||||
|
||||
dict_len = len(kwargs)
|
||||
dict_keys = kwargs.keys()
|
||||
all_args = (x, *args, *kwargs.values())
|
||||
len_all_args = len(all_args)
|
||||
split_index = len_all_args - dict_len
|
||||
|
||||
split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * chunk_size) for arg in all_args]
|
||||
chunk_sizes = tuple(map(len, split_all_args[0]))
|
||||
|
||||
for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
|
||||
chunked_args, chunked_kwargs_values = chunked_all_args[:split_index], chunked_all_args[split_index:]
|
||||
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
|
||||
yield chunk_size, (chunked_args, chunked_kwargs)
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
@@ -90,7 +135,7 @@ class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
beta = 0.9999,
|
||||
update_after_step = 1000,
|
||||
update_every = 10,
|
||||
):
|
||||
@@ -147,6 +192,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
@@ -173,6 +219,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -206,13 +253,25 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
*args,
|
||||
divisor = 1,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*args, **kwargs)
|
||||
return self.scaler.scale(loss / divisor)
|
||||
batch_size = x.shape[0]
|
||||
total_samples = 0
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
|
||||
total_loss += loss.item() * chunk_size
|
||||
total_samples += chunk_size
|
||||
|
||||
self.scaler.scale(loss * (chunk_size / batch_size)).backward()
|
||||
|
||||
return total_loss / total_samples
|
||||
|
||||
# decoder trainer
|
||||
|
||||
@@ -221,8 +280,9 @@ class DecoderTrainer(nn.Module):
|
||||
self,
|
||||
decoder,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
lr = 2e-5,
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
**kwargs
|
||||
@@ -247,13 +307,14 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
|
||||
|
||||
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||
for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
|
||||
optimizer = get_optimizer(
|
||||
unet.parameters(),
|
||||
lr = unet_lr,
|
||||
wd = unet_wd,
|
||||
eps = unet_eps,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -321,9 +382,20 @@ class DecoderTrainer(nn.Module):
|
||||
x,
|
||||
*,
|
||||
unet_number,
|
||||
divisor = 1,
|
||||
max_batch_size = None,
|
||||
**kwargs
|
||||
):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||
return self.scale(loss / divisor, unet_number = unet_number)
|
||||
batch_size = x.shape[0]
|
||||
total_samples = 0
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
|
||||
|
||||
total_loss += loss.item() * chunk_size
|
||||
total_samples += chunk_size
|
||||
|
||||
self.scale(loss * (chunk_size / batch_size), unet_number = unet_number).backward()
|
||||
|
||||
return total_loss / total_samples
|
||||
|
||||
Reference in New Issue
Block a user