set ability to do warmup steps for each unet during training

This commit is contained in:
Phil Wang
2022-07-05 16:20:49 -07:00
parent 3afdcdfe86
commit cf95d37e98
4 changed files with 53 additions and 5 deletions

View File

@@ -293,6 +293,7 @@ class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
warmup_steps: Optional[SingularOrIterable(int)] = None
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000

View File

@@ -3,11 +3,13 @@ import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from contextlib import nullcontext
from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -15,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
@@ -163,6 +167,7 @@ class DiffusionPriorTrainer(nn.Module):
group_wd_params = True,
device = None,
accelerator = None,
verbose = True,
**kwargs
):
super().__init__()
@@ -209,6 +214,10 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
# verbosity
self.verbose = verbose
# track steps internally
self.register_buffer('step', torch.tensor([0]))
@@ -216,6 +225,9 @@ class DiffusionPriorTrainer(nn.Module):
# accelerator wrappers
def print(self, msg):
if not self.verbose:
return
if exists(self.accelerator):
self.accelerator.print(msg)
else:
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -450,13 +463,15 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
@@ -468,6 +483,13 @@ class DecoderTrainer(nn.Module):
optimizers.append(optimizer)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
schedulers.append(scheduler)
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -478,12 +500,24 @@ class DecoderTrainer(nn.Module):
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
schedulers = list(self.accelerator.prepare(*schedulers))
self.decoder = decoder
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
setattr(self, f'optim{opt_ind}', optimizer)
# store schedulers
for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
setattr(self, f'sched{sched_ind}', scheduler)
# store warmup schedulers
self.warmup_schedulers = warmup_schedulers
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -516,12 +550,17 @@ class DecoderTrainer(nn.Module):
if only_model:
return loaded_obj
for ind in range(0, self.num_unets):
for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
warmup_scheduler = self.warmup_schedulers[ind]
self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step
if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
@@ -554,12 +593,20 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
scheduler = getattr(self, f'sched{index}')
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
optimizer.step()
optimizer.zero_grad()
warmup_scheduler = self.warmup_schedulers[index]
scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
with scheduler_context():
scheduler.step()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
@@ -614,7 +661,6 @@ class DecoderTrainer(nn.Module):
total_loss = 0.
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
# with autocast(enabled = self.amp):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac

View File

@@ -1 +1 @@
__version__ = '0.16.3'
__version__ = '0.16.5'

View File

@@ -37,6 +37,7 @@ setup(
'packaging',
'pillow',
'pydantic',
'pytorch-warmup',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',