From cf95d37e98611a17c064e8fb67029033bffdd6ce Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 5 Jul 2022 16:20:49 -0700
Subject: [PATCH] set ability to do warmup steps for each unet during training

---
 dalle2_pytorch/train_configs.py |  1 +
 dalle2_pytorch/trainer.py       | 54 ++++++++++++++++++++++++++++++---
 dalle2_pytorch/version.py       |  2 +-
 setup.py                        |  1 +
 4 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 8e7560f..a016981 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -293,6 +293,7 @@ class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: SingularOrIterable(float) = 1e-4
     wd: SingularOrIterable(float) = 0.01
+    warmup_steps: Optional[SingularOrIterable(int)] = None
     find_unused_parameters: bool = True
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 4afb6a4..8b88f92 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -3,11 +3,13 @@ import copy
 from pathlib import Path
 from math import ceil
 from functools import partial, wraps
+from contextlib import nullcontext
 from collections.abc import Iterable
 
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -15,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version
 
+import pytorch_warmup as warmup
+
 from ema_pytorch import EMA
 
 from accelerate import Accelerator
@@ -163,6 +167,7 @@ class DiffusionPriorTrainer(nn.Module):
         group_wd_params = True,
         device = None,
         accelerator = None,
+        verbose = True,
         **kwargs
     ):
         super().__init__()
@@ -209,6 +214,10 @@ class DiffusionPriorTrainer(nn.Module):
 
         self.max_grad_norm = max_grad_norm
 
+        # verbosity
+
+        self.verbose = verbose
+
         # track steps internally
 
         self.register_buffer('step', torch.tensor([0]))
@@ -216,6 +225,9 @@
     # accelerator wrappers
 
     def print(self, msg):
+        if not self.verbose:
+            return
+
         if exists(self.accelerator):
             self.accelerator.print(msg)
         else:
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
         lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
+        warmup_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -450,13 +463,15 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
+        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
 
         assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
         optimizers = []
+        schedulers = []
+        warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
@@ -468,6 +483,13 @@ class DecoderTrainer(nn.Module):
 
             optimizers.append(optimizer)
 
+            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+            warmup_schedulers.append(warmup_scheduler)
+
+            schedulers.append(scheduler)
+
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
 
@@ -478,12 +500,24 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        schedulers = list(self.accelerator.prepare(*schedulers))
 
         self.decoder = decoder
 
+        # store optimizers
+
         for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
             setattr(self, f'optim{opt_ind}', optimizer)
 
+        # store schedulers
+
+        for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
+            setattr(self, f'sched{sched_ind}', scheduler)
+
+        # store warmup schedulers
+
+        self.warmup_schedulers = warmup_schedulers
+
     def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
@@ -516,12 +550,17 @@ class DecoderTrainer(nn.Module):
         if only_model:
             return loaded_obj
 
-        for ind in range(0, self.num_unets):
+        for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
+
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+            warmup_scheduler = self.warmup_schedulers[ind]
 
             self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])
 
+            if exists(warmup_scheduler):
+                warmup_scheduler.last_step = last_step
+
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
@@ -554,12 +593,20 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1
 
         optimizer = getattr(self, f'optim{index}')
+        scheduler = getattr(self, f'sched{index}')
 
         if exists(self.max_grad_norm):
             self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
+
         optimizer.step()
         optimizer.zero_grad()
 
+        warmup_scheduler = self.warmup_schedulers[index]
+        scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
+
+        with scheduler_context():
+            scheduler.step()
+
         if self.use_ema:
             ema_unet = self.ema_unets[index]
             ema_unet.update()
@@ -614,7 +661,6 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.
 
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            # with autocast(enabled = self.amp):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 66e314a..1494edd 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.16.3'
+__version__ = '0.16.5'
diff --git a/setup.py b/setup.py
index 4f2a251..26491f7 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,7 @@ setup(
     'packaging',
     'pillow',
     'pydantic',
+    'pytorch-warmup',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',
     'torch>=1.10',
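
For context, the per-unet warmup pattern this patch wires into DecoderTrainer (a no-op LambdaLR paired with an optional pytorch_warmup.LinearWarmup, stepped inside the warmup scheduler's dampening() context) can be sketched in isolation as below. This sketch is not part of the patch; the toy linear model, Adam optimizer, and warmup_period value are illustrative assumptions, not values from the repository.

    # Minimal usage sketch (assumes a toy model and warmup_period = 10) showing the
    # mechanism the patch installs per unet: a constant LambdaLR schedule whose steps
    # are dampened by pytorch_warmup.LinearWarmup during the first warmup_period steps.

    import torch
    from torch.optim.lr_scheduler import LambdaLR
    import pytorch_warmup as warmup

    model = torch.nn.Linear(4, 4)                                   # stand-in for a single unet
    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)   # constant base schedule
    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 10)

    for step in range(20):
        loss = model(torch.randn(2, 4)).sum()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        with warmup_scheduler.dampening():                          # scales lr while warming up
            scheduler.step()

        print(step, optimizer.param_groups[0]['lr'])                # ramps linearly toward 1e-4

When no warmup_steps is given for a unet, the trainer substitutes nullcontext for dampening(), so the base scheduler steps undamped and the behavior matches the previous releases.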