mirror of https://github.com/lucidrains/DALLE2-pytorch.git
add cosine annealing lr schedule

@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
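
For context, a minimal sketch of what the newly imported CosineAnnealingLR does (illustration only, not part of this commit; everything below is stock torch API): the learning rate follows a half-cosine from the optimizer's base lr down toward eta_min over T_max scheduler steps.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# toy optimizer over a single dummy parameter
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr = 3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000, eta_min = 0.)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # lr drifts down along the cosine curve
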
@@ -433,6 +433,7 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,7 +455,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))

         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

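
The map(partial(cast_tuple, ...)) line above broadcasts each hyperparameter to one value per unet, so cosine_decay_max_steps can be given either once or per unet. A small sketch of that broadcasting; the cast_tuple definition here is assumed from the repo's usual helper shape, not taken from this diff:

# assumed helper: broadcast a scalar to a tuple of the requested length,
# pass tuples through untouched
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

num_unets = 2
print(cast_tuple(1e-4, num_unets))             # (0.0001, 0.0001)
print(cast_tuple((50000, 150000), num_unets))  # (50000, 150000) -> per-unet T_max
print(cast_tuple(None, num_unets))             # (None, None) -> constant lr fallback
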
@@ -462,7 +463,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -478,7 +479,11 @@ class DecoderTrainer(nn.Module):
             )

             optimizers.append(optimizer)
-            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+            if exists(unet_cosine_decay_max_steps):
+                scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+            else:
+                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

             warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
             warmup_schedulers.append(warmup_scheduler)
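
The scheduler selection added above, pulled out into a standalone sketch so it can be run in isolation: the exists helper is reproduced here, build_scheduler is just a local name for this sketch, and everything else is stock torch API.

import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

def exists(val):
    return val is not None

def build_scheduler(optimizer, cosine_decay_max_steps = None):
    # mirror of the per-unet branch: cosine decay when a max step count is given,
    # otherwise a constant learning rate via an identity LambdaLR
    if exists(cosine_decay_max_steps):
        return CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
    return LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr = 1e-4)
print(type(build_scheduler(optimizer)).__name__)           # LambdaLR
print(type(build_scheduler(optimizer, 100_000)).__name__)  # CosineAnnealingLR

In the trainer itself this presumably means something like DecoderTrainer(decoder, lr = 1e-4, warmup_steps = 500, cosine_decay_max_steps = 100_000); the keyword names come from the hunks above, though the full constructor call is not shown in this diff.
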
@@ -558,9 +563,15 @@ class DecoderTrainer(nn.Module):

         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}

         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
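
After this change each unet contributes a scheduler entry alongside its optimizer entry in the saved object. A toy sketch of the resulting key layout (other checkpoint contents, such as model weights, are omitted here):

save_obj = {}
num_unets = 2

for ind in range(num_unets):
    # placeholder strings stand in for the real state dicts (or None for Identity unets)
    save_obj = {
        **save_obj,
        f'optim{ind}': f'<optimizer {ind} state dict or None>',
        f'sched{ind}': f'<scheduler {ind} state dict or None>',
    }

print(sorted(save_obj))  # ['optim0', 'optim1', 'sched0', 'sched1']
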
@@ -581,10 +592,18 @@ class DecoderTrainer(nn.Module):

             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)

+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
             warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])

+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
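
The point of round-tripping the scheduler state on load is that a cosine schedule resumes mid-curve rather than restarting from the base learning rate. A hedged sketch of that behavior using stock torch objects (the trainer's own save/load plumbing is not reproduced here):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr = 1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

for _ in range(500):
    optimizer.step()
    scheduler.step()

state = scheduler.state_dict()  # what the trainer now stores under f'sched{ind}'

# a freshly built scheduler restored from that state resumes at step 500
fresh_optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr = 1e-4)
fresh_scheduler = CosineAnnealingLR(fresh_optimizer, T_max = 1000)
fresh_scheduler.load_state_dict(state)

print(fresh_scheduler.last_epoch)     # 500
print(fresh_scheduler.get_last_lr())  # lr halfway down the cosine, not the base lr
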
@@ -1 +1 @@
-__version__ = '1.7.0'
+__version__ = '1.8.0'