add ability to set warmup steps for each unet during training

commit cf95d37e98
parent 3afdcdfe86
Author: Phil Wang
Date:   2022-07-05 16:20:49 -07:00

4 changed files with 53 additions and 5 deletions
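
The new warmup_steps argument mirrors lr and wd: it is broadcast to one value per unet, so it can be given either as a single int or as one entry per unet, and it defaults to None (no warmup). A minimal usage sketch against DecoderTrainer from the diff below, assuming a decoder with two unets has already been built as in the README (the warmup values themselves are hypothetical):

    from dalle2_pytorch.trainer import DecoderTrainer

    # `decoder` is assumed to be an existing Decoder with two unets
    trainer = DecoderTrainer(
        decoder,
        lr = 1e-4,
        wd = 1e-2,
        warmup_steps = (500, 1000)  # hypothetical: 500 warmup steps for unet 1, 1000 for unet 2
    )

    # a single value is also accepted and applied to every unet, e.g. warmup_steps = 500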

View File

@@ -293,6 +293,7 @@ class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: SingularOrIterable(float) = 1e-4
     wd: SingularOrIterable(float) = 0.01
+    warmup_steps: Optional[SingularOrIterable(int)] = None
     find_unused_parameters: bool = True
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000

View File

@@ -3,11 +3,13 @@ import copy
 from pathlib import Path
 from math import ceil
 from functools import partial, wraps
+from contextlib import nullcontext
 from collections.abc import Iterable

 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -15,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version
+
+import pytorch_warmup as warmup
 from ema_pytorch import EMA
 from accelerate import Accelerator
@@ -163,6 +167,7 @@ class DiffusionPriorTrainer(nn.Module):
         group_wd_params = True,
         device = None,
         accelerator = None,
+        verbose = True,
         **kwargs
     ):
         super().__init__()
@@ -209,6 +214,10 @@ class DiffusionPriorTrainer(nn.Module):
         self.max_grad_norm = max_grad_norm

+        # verbosity
+
+        self.verbose = verbose
+
         # track steps internally

         self.register_buffer('step', torch.tensor([0]))
@@ -216,6 +225,9 @@ class DiffusionPriorTrainer(nn.Module):
     # accelerator wrappers

     def print(self, msg):
+        if not self.verbose:
+            return
+
         if exists(self.accelerator):
             self.accelerator.print(msg)
         else:
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
         lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
+        warmup_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -450,13 +463,15 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
+        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))

         assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

         optimizers = []
+        schedulers = []
+        warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
@@ -468,6 +483,13 @@ class DecoderTrainer(nn.Module):
             optimizers.append(optimizer)

+            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+            warmup_schedulers.append(warmup_scheduler)
+
+            schedulers.append(scheduler)

             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
@@ -478,12 +500,24 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))

         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        schedulers = list(self.accelerator.prepare(*schedulers))

         self.decoder = decoder

+        # store optimizers
+
         for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
             setattr(self, f'optim{opt_ind}', optimizer)

+        # store schedulers
+
+        for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
+            setattr(self, f'sched{sched_ind}', scheduler)
+
+        # store warmup schedulers
+
+        self.warmup_schedulers = warmup_schedulers
+
     def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
@@ -516,12 +550,17 @@ class DecoderTrainer(nn.Module):
         if only_model:
             return loaded_obj

-        for ind in range(0, self.num_unets):
+        for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+            warmup_scheduler = self.warmup_schedulers[ind]
+
             self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])

+            if exists(warmup_scheduler):
+                warmup_scheduler.last_step = last_step
+
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
@@ -554,12 +593,20 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1

         optimizer = getattr(self, f'optim{index}')
+        scheduler = getattr(self, f'sched{index}')

         if exists(self.max_grad_norm):
             self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients

         optimizer.step()
         optimizer.zero_grad()

+        warmup_scheduler = self.warmup_schedulers[index]
+        scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
+
+        with scheduler_context():
+            scheduler.step()

         if self.use_ema:
             ema_unet = self.ema_unets[index]
             ema_unet.update()
@@ -614,7 +661,6 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.

         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            # with autocast(enabled = self.amp):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
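
The scheduling pattern adopted above pairs a constant LambdaLR schedule with pytorch-warmup: LinearWarmup's dampening() context temporarily scales the learning rate down while the step count is inside the warmup period, and has no effect afterwards. A self-contained sketch of the same pattern outside the trainer (the dummy parameter, learning rate, warmup period, and step count are illustrative, not taken from the commit):

    import torch
    from torch.optim.lr_scheduler import LambdaLR
    import pytorch_warmup as warmup

    # dummy parameter and optimizer, purely for illustration
    param = torch.nn.Parameter(torch.zeros(3))
    optimizer = torch.optim.Adam([param], lr = 1e-4)

    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)           # constant schedule, as in the trainer
    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 500)  # linear warmup over 500 steps

    for step in range(1000):
        (param.sum() * 0.).backward()   # stand-in for a real loss
        optimizer.step()
        optimizer.zero_grad()

        with warmup_scheduler.dampening():  # scales the lr down linearly until the warmup period ends
            scheduler.step()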

View File

@@ -1 +1 @@
-__version__ = '0.16.3'
+__version__ = '0.16.5'

View File

@@ -37,6 +37,7 @@ setup(
     'packaging',
     'pillow',
     'pydantic',
+    'pytorch-warmup',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',
     'torch>=1.10',