@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )

-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
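
The branch above keeps the previous behaviour (a constant-factor LambdaLR) unless cosine_decay_max_steps is supplied, in which case the learning rate follows a cosine curve toward zero over T_max steps. A minimal, self-contained sketch of the same selection logic on a dummy parameter (plain PyTorch, outside the trainer; exists() is re-implemented here to mirror the library helper):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

def exists(val):
    # same convention as the trainer's helper: "was this argument set?"
    return val is not None

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr = 3e-4)

cosine_decay_max_steps = 10  # set to None to fall back to a constant learning rate

if exists(cosine_decay_max_steps):
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)

for step in range(5):
    optimizer.step()   # normally preceded by a forward / backward pass
    scheduler.step()
    print(step, optimizer.param_groups[0]['lr'])  # decays along a cosine curve over T_max steps
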
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
+
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmupstep
         if exists(self.warmup_scheduler):
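
Because CosineAnnealingLR is stateful (it tracks how many steps it has taken), the checkpoint now stores scheduler.state_dict() alongside the optimizer state and restores it on load. A small round-trip sketch independent of the trainer (the file name and step counts are illustrative):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr = 3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

for _ in range(100):
    optimizer.step()
    scheduler.step()

# save optimizer + scheduler state, as the trainer's save() now does
torch.save({
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, 'checkpoint.pt')

# restore into freshly constructed objects, as load() does
new_optimizer = Adam([param], lr = 3e-4)
new_scheduler = CosineAnnealingLR(new_optimizer, T_max = 1000)

ckpt = torch.load('checkpoint.pt')
new_optimizer.load_state_dict(ckpt['optimizer'])
new_scheduler.load_state_dict(ckpt['scheduler'])

assert new_scheduler.last_epoch == scheduler.last_epoch  # decay resumes where it left off
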
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):

         # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
         if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                 self.scheduler.step()

         if self.use_ema:
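
Since warmup is now optional, the dampening context can no longer be entered unconditionally; nullcontext (from the standard library) stands in when no warmup scheduler exists. The pattern in isolation, assuming the pytorch-warmup package that provides the LinearWarmup used by these trainers:

from contextlib import nullcontext

import torch
import pytorch_warmup as warmup
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

def exists(val):
    return val is not None

param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([param], lr = 3e-4)
scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)

warmup_steps = 10  # set to None to train without warmup
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None

for step in range(20):
    optimizer.step()

    # dampening() rescales the lr during warmup; nullcontext is a no-op when warmup is disabled
    sched_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
    with sched_context():
        scheduler.step()

    print(step, optimizer.param_groups[0]['lr'])
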
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))

         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
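
cosine_decay_max_steps joins the hyperparameters that are broadcast per unet: a scalar is repeated for every unet, a tuple is used as given. A sketch of that broadcasting with a minimal local stand-in for the library's cast_tuple helper (the helper below is illustrative, not the package's own implementation):

from functools import partial

def cast_tuple(val, length = 1):
    # minimal stand-in: pass tuples through, repeat scalars `length` times
    return val if isinstance(val, tuple) else (val,) * length

num_unets = 2

lr = 1e-4                       # scalar: applied to every unet
warmup_steps = (500, 1000)      # tuple: one value per unet
cosine_decay_max_steps = None   # None broadcasts too, keeping constant lr everywhere

lr, warmup_steps, cosine_decay_max_steps = map(
    partial(cast_tuple, length = num_unets),
    (lr, warmup_steps, cosine_decay_max_steps)
)

print(lr)                      # (0.0001, 0.0001)
print(warmup_steps)            # (500, 1000)
print(cosine_decay_max_steps)  # (None, None)
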
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
                 )

                 optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)

                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
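
Each unet now gets a matched triple of optimizer, scheduler, and optional warmup scheduler. A self-contained sketch of the same bookkeeping with two toy "unets", using plain Adam as a stand-in for the library's get_optimizer and assuming pytorch-warmup is installed:

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import pytorch_warmup as warmup

def exists(val):
    return val is not None

# two toy unets; the second trains without warmup and without cosine decay
unets = [nn.Linear(8, 8), nn.Linear(8, 8)]
lr = (1e-4, 1e-4)
warmup_steps = (500, None)
cosine_decay_max_steps = (100_000, None)

optimizers, schedulers, warmup_schedulers = [], [], []

for unet, unet_lr, unet_warmup_steps, unet_cosine_decay_max_steps in zip(unets, lr, warmup_steps, cosine_decay_max_steps):
    optimizer = Adam(unet.parameters(), lr = unet_lr)
    optimizers.append(optimizer)

    if exists(unet_cosine_decay_max_steps):
        scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
    schedulers.append(scheduler)

    warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
    warmup_schedulers.append(warmup_scheduler)

print([type(s).__name__ for s in schedulers])                                 # ['CosineAnnealingLR', 'LambdaLR']
print([type(w).__name__ if exists(w) else None for w in warmup_schedulers])   # ['LinearWarmup', None]
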
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):

         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}

         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):

             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)

             warmup_scheduler = self.warmup_schedulers[ind]

-            if optimizer is not None:
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])

+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
+
+            if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
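
In the DecoderTrainer checkpoint each unet contributes an optim{ind} and a sched{ind} entry, with None standing in for nn.Identity placeholders, and loading guards every entry with an existence check. A standalone sketch of that layout and the guarded restore (dummy modules and optimizers throughout; only the key naming mirrors the code above):

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

def exists(val):
    return val is not None

# unet 0 is trainable, unet 1 is a placeholder (as when only one unet is being trained)
unets = [nn.Linear(8, 8), nn.Identity()]

optimizers, schedulers = [], []
for unet in unets:
    if isinstance(unet, nn.Identity):
        optimizers.append(None)
        schedulers.append(None)
    else:
        optimizer = Adam(unet.parameters(), lr = 1e-4)
        optimizers.append(optimizer)
        schedulers.append(CosineAnnealingLR(optimizer, T_max = 1000))

# save: one optim{ind} / sched{ind} pair per unet, None where nothing exists
save_obj = {}
for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
    optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
    scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
    save_obj = {**save_obj, f'optim{ind}': optimizer_state_dict, f'sched{ind}': scheduler_state_dict}

print(sorted(save_obj.keys()))  # ['optim0', 'optim1', 'sched0', 'sched1']

# load: only restore what exists on both sides
for ind, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
    if exists(optimizer):
        optimizer.load_state_dict(save_obj[f'optim{ind}'])
    if exists(scheduler):
        scheduler.load_state_dict(save_obj[f'sched{ind}'])
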