Compare commits

...

10 Commits

Author SHA1 Message Date
Phil Wang
1bd8a7835a attempting to fix issue with deepspeed fp16 seeing overflowing gradient 2022-07-06 08:27:34 -07:00
Phil Wang
f33453df9f debugging with Aidan 2022-07-05 18:22:43 -07:00
Phil Wang
1e4bb2bafb cast long as float before deriving sinusoidal pos emb 2022-07-05 18:01:22 -07:00
Phil Wang
ee75515c7d remove forcing of softmax in f32, in case it is interfering with deepspeed 2022-07-05 16:53:58 -07:00
Phil Wang
ec68243479 set ability to do warmup steps for each unet during training 2022-07-05 16:24:16 -07:00
Phil Wang
3afdcdfe86 need to keep track of training steps separately for each unet in decoder trainer 2022-07-05 15:17:59 -07:00
Phil Wang
b9a908ff75 bring in two tricks from the cogview paper for reducing the chances of overflow, for attention and layernorm 2022-07-05 14:27:04 -07:00
Phil Wang
e1fe3089df do bias-less layernorm manually 2022-07-05 13:09:58 -07:00
Phil Wang
6d477d7654 link to dalle2 laion 2022-07-05 11:43:07 -07:00
Phil Wang
531fe4b62f status 2022-07-05 10:46:55 -07:00
6 changed files with 103 additions and 27 deletions

View File

@@ -20,18 +20,20 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. Two researchers have also confirmed that the Decoder is working for them.
<img src="./samples/oxford.png" width="600px" />
<img src="./samples/oxford.png" width="450px" />
*ongoing at 21k steps*
- <a href="https://twitter.com/Buntworthy/status/1529475416775434240?t=0GEge3Kr9I36cjcUVCQUTg">Justin Pinkney</a> successfully trained the diffusion prior in the repository for his CLIP to StyleGAN2 text-to-image application
- <a href="https://github.com/rom1504">Romain</a> has scaled up training to 800 GPUs with the available scripts without any issues
## Pre-Trained Models
- LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/3d5rytsa?workspace=">Another test run with sparse attention</a>
- DALL-E 2 🚧
- DALL-E 2 🚧 - <a href="https://github.com/LAION-AI/dalle2-laion">DALL-E 2 LAION repository</a>
## Appreciation

View File

@@ -335,6 +335,10 @@ def approx_standard_normal_cdf(x):
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
assert x.shape == means.shape == log_scales.shape
# attempt to avoid NaN gradients when learned variance is turned on
# under DeepSpeed fp16, by using a larger epsilon for non-float32 inputs
eps = 1e-12 if x.dtype == torch.float32 else 1e-5
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1. / 255.)
@@ -349,7 +353,7 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
log_cdf_plus,
torch.where(x > thres,
log_one_minus_cdf_min,
log(cdf_delta)))
log(cdf_delta, eps = eps)))
return log_probs
@@ -490,14 +494,16 @@ class NoiseScheduler(nn.Module):
# diffusion prior
class LayerNorm(nn.Module):
def __init__(self, dim):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
x = x / x.amax(dim = -1, keepdim = True).detach()
var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
return (x - mean) * (var + self.eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
@@ -506,10 +512,10 @@ class ChanLayerNorm(nn.Module):
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
x = x / x.amax(dim = 1, keepdim = True).detach()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
return (x - mean) * (var + self.eps).rsqrt() * self.g
class Residual(nn.Module):
def __init__(self, fn):
@@ -629,10 +635,13 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
rotary_emb = None
rotary_emb = None,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -696,7 +705,10 @@ class Attention(nn.Module):
# attention
attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
@@ -1119,11 +1131,12 @@ class SinusoidalPosEmb(nn.Module):
self.dim = dim
def forward(self, x):
dtype, device = x.dtype, x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
class Block(nn.Module):
def __init__(
@@ -1210,10 +1223,12 @@ class CrossAttention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
norm_context = False
norm_context = False,
pb_relax_alpha = 32 ** 2
):
super().__init__()
self.scale = dim_head ** -0.5
self.pb_relax_alpha = pb_relax_alpha
self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
self.heads = heads
inner_dim = dim_head * heads
@@ -1259,7 +1274,10 @@ class CrossAttention(nn.Module):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
attn = sim.softmax(dim = -1, dtype = torch.float32)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
sim = sim * self.pb_relax_alpha
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1613,6 +1631,7 @@ class Unet(nn.Module):
# time conditioning
time = time.type_as(x)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)

View File

@@ -293,6 +293,7 @@ class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000

View File

@@ -3,10 +3,13 @@ import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -14,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
@@ -162,6 +167,7 @@ class DiffusionPriorTrainer(nn.Module):
group_wd_params = True,
device = None,
accelerator = None,
verbose = True,
**kwargs
):
super().__init__()
@@ -208,6 +214,10 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
# verbosity
self.verbose = verbose
# track steps internally
self.register_buffer('step', torch.tensor([0]))
@@ -215,6 +225,9 @@ class DiffusionPriorTrainer(nn.Module):
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
@@ -428,6 +441,7 @@ class DecoderTrainer(nn.Module):
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -449,13 +463,15 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
@@ -467,6 +483,13 @@ class DecoderTrainer(nn.Module):
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -474,15 +497,27 @@ class DecoderTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
schedulers = list(self.accelerator.prepare(*schedulers))
self.decoder = decoder
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
# store schedulers
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
setattr(self, f'sched{sched_ind}', scheduler)
# store warmup schedulers
self.warmup_schedulers = warmup_schedulers
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -491,7 +526,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.accelerator.unwrap_model(self.decoder).state_dict(),
version = __version__,
step = self.step.item(),
steps = self.steps.cpu(),
**kwargs
)
@@ -510,17 +545,22 @@ class DecoderTrainer(nn.Module):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
self.steps.copy_(loaded_obj['steps'])
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind]
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
@@ -539,6 +579,12 @@ class DecoderTrainer(nn.Module):
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def increment_step(self, unet_number):
assert 1 <= unet_number <= self.num_unets
unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
@@ -547,17 +593,25 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
scheduler = getattr(self, f'sched{index}')
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
optimizer.zero_grad()
warmup_scheduler = self.warmup_schedulers[index]
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
with scheduler_context():
scheduler.step()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
self.step += 1
self.increment_step(unet_number)
@torch.no_grad()
@cast_torch_tensor
@@ -607,7 +661,6 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac

View File

@@ -1 +1 @@
__version__ = '0.15.4'
__version__ = '0.16.10'

View File

@@ -37,6 +37,7 @@ setup(
'packaging',
'pillow',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',