diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index db44c71..f154512 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -433,6 +433,7 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,7 +455,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
 
         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
@@ -462,7 +463,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -478,7 +479,11 @@ class DecoderTrainer(nn.Module):
                 )
 
                 optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
@@ -558,9 +563,15 @@ class DecoderTrainer(nn.Module):
 
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +592,18 @@ class DecoderTrainer(nn.Module):
 
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
             warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
 
+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 0e1a38d..b280975 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.7.0'
+__version__ = '1.8.0'
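Usage note (illustrative sketch, not part of the diff): with this change, cosine_decay_max_steps can be passed to DecoderTrainer either as a single value or as a per-unet tuple; like lr, wd, eps and warmup_steps it is broadcast to num_unets via cast_tuple, and any unet whose entry is None keeps the previous constant-LR LambdaLR behaviour. The sketch below assumes a decoder object with two unets constructed elsewhere, and all hyperparameter values shown are placeholders; the save/load paths touched by the diff will additionally round-trip the per-unet scheduler state under the sched{ind} keys.

# Illustrative only: `decoder` is assumed to be an existing Decoder instance
# with two unets; numeric values are placeholders, not recommendations.
from dalle2_pytorch.trainer import DecoderTrainer

trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,                                   # scalars are broadcast to every unet via cast_tuple
    wd = 1e-2,
    warmup_steps = (500, 1000),                  # per-unet linear warmup (pytorch-warmup)
    cosine_decay_max_steps = (50_000, 200_000)   # per-unet T_max for CosineAnnealingLR;
                                                 # leave as None to keep the flat LambdaLR schedule
)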