@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
+from torch.optim.lr_scheduler import LambdaLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -433,7 +433,6 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
-        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -455,7 +454,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
+        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
 
         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
@@ -463,7 +462,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -479,11 +478,7 @@ class DecoderTrainer(nn.Module):
                 )
 
                 optimizers.append(optimizer)
 
-                if exists(unet_cosine_decay_max_steps):
-                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
-                else:
-                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
@@ -563,15 +558,9 @@ class DecoderTrainer(nn.Module):
 
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
-            scheduler_key = f'sched{ind}'
-
             optimizer = getattr(self, optimizer_key)
-            scheduler = getattr(self, scheduler_key)
-
-            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
-            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
-
-            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
+            state_dict = optimizer.state_dict() if optimizer is not None else None
+            save_obj = {**save_obj, optimizer_key: state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -592,18 +581,10 @@ class DecoderTrainer(nn.Module):
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
 
-            scheduler_key = f'sched{ind}'
-            scheduler = getattr(self, scheduler_key)
-
             warmup_scheduler = self.warmup_schedulers[ind]
 
-            if exists(optimizer):
+            if optimizer is not None:
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
 
-            if exists(scheduler):
-                scheduler.load_state_dict(loaded_obj[scheduler_key])
-
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step