mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add cosine annealing lr schedule
This commit is contained in:
@@ -9,7 +9,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -433,6 +433,7 @@ class DecoderTrainer(nn.Module):
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -454,7 +455,7 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
@@ -462,7 +463,7 @@ class DecoderTrainer(nn.Module):
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
@@ -478,7 +479,11 @@ class DecoderTrainer(nn.Module):
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
@@ -558,9 +563,15 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -581,10 +592,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
if optimizer is not None:
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.7.0'
|
||||
__version__ = '1.8.0'
|
||||
|
||||
Reference in New Issue
Block a user