Compare commits

..

3 Commits
1.7.0 ... 1.8.1

Author SHA1 Message Date
Phil Wang
8f38339c2b give diffusion prior trainer cosine annealing lr too 2022-08-15 07:38:01 -07:00
Phil Wang
6b9b4b9e5e add cosine annealing lr schedule 2022-08-15 07:29:56 -07:00
Phil Wang
44e09d5a4d add weight standardization behind feature flag, which may potentially work well with group norm 2022-08-14 11:34:45 -07:00
3 changed files with 36 additions and 14 deletions

View File

@@ -1456,10 +1456,6 @@ class WeightStandardizedConv2d(nn.Conv2d):
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -182,6 +182,7 @@ class DiffusionPriorTrainer(nn.Module):
max_grad_norm = None,
group_wd_params = True,
warmup_steps = 1,
cosine_decay_max_steps = None,
**kwargs
):
super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmupstep
if exists(self.warmup_scheduler):
@@ -433,6 +440,7 @@ class DecoderTrainer(nn.Module):
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -454,7 +462,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -462,7 +470,7 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
@@ -478,7 +486,11 @@ class DecoderTrainer(nn.Module):
)
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
@@ -558,9 +570,15 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}
scheduler = getattr(self, scheduler_key)
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +599,18 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scheduler_key = f'sched{ind}'
scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
if optimizer is not None:
if exists(optimizer):
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(scheduler):
scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step

View File

@@ -1 +1 @@
__version__ = '1.7.0'
__version__ = '1.8.1'