mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)

set ability to do warmup steps for each unet during training

@@ -293,6 +293,7 @@ class DecoderTrainConfig(BaseModel):
     epochs: int = 20
     lr: SingularOrIterable(float) = 1e-4
     wd: SingularOrIterable(float) = 0.01
+    warmup_steps: Optional[SingularOrIterable(int)] = None
     find_unused_parameters: bool = True
     max_grad_norm: SingularOrIterable(float) = 0.5
     save_every_n_samples: int = 100000
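As a usage note: the `SingularOrIterable(int)` annotation means `warmup_steps` may be given once for all unets or once per unet. A hypothetical config snippet under that assumption (the other keys are just the fields shown above):

```python
# Hypothetical decoder training config values; only the warmup_steps semantics
# (single value or one value per unet) are implied by the diff above.
decoder_train_config = {
    'epochs': 20,
    'lr': [1e-4, 1.2e-4],          # one learning rate per unet
    'wd': 0.01,                    # a single value is broadcast to all unets
    'warmup_steps': [500, 1000],   # new field: linear warmup steps per unet (or None to disable)
    'max_grad_norm': 0.5,
}
```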
@@ -3,11 +3,13 @@ import copy
 from pathlib import Path
 from math import ceil
 from functools import partial, wraps
+from contextlib import nullcontext
 from collections.abc import Iterable

 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -15,6 +17,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version

+import pytorch_warmup as warmup
+
 from ema_pytorch import EMA

 from accelerate import Accelerator
@@ -163,6 +167,7 @@ class DiffusionPriorTrainer(nn.Module):
         group_wd_params = True,
         device = None,
         accelerator = None,
+        verbose = True,
         **kwargs
     ):
         super().__init__()
@@ -209,6 +214,10 @@ class DiffusionPriorTrainer(nn.Module):

         self.max_grad_norm = max_grad_norm

+        # verbosity
+
+        self.verbose = verbose
+
         # track steps internally

         self.register_buffer('step', torch.tensor([0]))
@@ -216,6 +225,9 @@ class DiffusionPriorTrainer(nn.Module):
     # accelerator wrappers

     def print(self, msg):
+        if not self.verbose:
+            return
+
         if exists(self.accelerator):
             self.accelerator.print(msg)
         else:
@@ -429,6 +441,7 @@ class DecoderTrainer(nn.Module):
         lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
+        warmup_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
@@ -450,13 +463,15 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet

-        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
+        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))

         assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'

         optimizers = []
+        schedulers = []
+        warmup_schedulers = []

-        for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
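`cast_tuple` (a helper defined elsewhere in this repo) is what lets `warmup_steps` be a single value or one value per unet. A minimal sketch of that broadcasting behavior, not the repo's exact helper:

```python
# Minimal sketch of the broadcasting relied on above (not the repo's exact helper):
# scalars are repeated once per unet, iterables are passed through as tuples.
def cast_tuple(val, length = 1):
    return tuple(val) if isinstance(val, (list, tuple)) else (val,) * length

num_unets = 2
lr, wd, warmup_steps = (cast_tuple(v, length = num_unets) for v in (1e-4, 1e-2, None))
assert warmup_steps == (None, None)  # no warmup scheduler will be built for either unet
```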
@@ -468,6 +483,13 @@ class DecoderTrainer(nn.Module):

             optimizers.append(optimizer)

+            scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+            warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+            warmup_schedulers.append(warmup_scheduler)
+
+            schedulers.append(scheduler)
+
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))

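Stand-alone, the per-unet pairing built above looks roughly like the following, assuming the `pytorch-warmup` package: a constant `LambdaLR` (factor 1.0) serves as the base schedule, and `LinearWarmup` dampens the learning rate over the first `warmup_period` steps.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR
import pytorch_warmup as warmup

# Stand-alone sketch of the per-unet optimizer/scheduler pairing; the toy
# linear layer and the warmup_period value are placeholders.
unet = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(unet.parameters(), lr = 1e-4)

scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)           # constant base schedule
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 500)  # linear ramp over the first 500 steps
```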
@@ -478,12 +500,24 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))

         decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        schedulers = list(self.accelerator.prepare(*schedulers))

         self.decoder = decoder

+        # store optimizers
+
         for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
             setattr(self, f'optim{opt_ind}', optimizer)

+        # store schedulers
+
+        for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
+            setattr(self, f'sched{sched_ind}', scheduler)
+
+        # store warmup schedulers
+
+        self.warmup_schedulers = warmup_schedulers
+
     def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
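The schedulers follow the same path as the optimizers: through `accelerator.prepare` and onto the trainer as numbered attributes, so they travel with the module's state. A rough sketch of that pattern with two toy unets, assuming a recent accelerate version that accepts LR schedulers in `prepare` (all names here are placeholders):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR
from accelerate import Accelerator

# Rough sketch of the prepare-then-register pattern above, with toy stand-ins.
accelerator = Accelerator()
unets = [torch.nn.Linear(16, 16) for _ in range(2)]
optimizers = [torch.optim.AdamW(u.parameters(), lr = 1e-4) for u in unets]
schedulers = [LambdaLR(opt, lr_lambda = lambda step: 1.0) for opt in optimizers]

optimizers = list(accelerator.prepare(*optimizers))
schedulers = list(accelerator.prepare(*schedulers))

trainer = torch.nn.Module()  # stand-in for the DecoderTrainer module
for ind, (opt, sched) in enumerate(zip(optimizers, schedulers)):
    setattr(trainer, f'optim{ind}', opt)   # addressable per unet, like in the diff
    setattr(trainer, f'sched{ind}', sched)
```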
@@ -516,12 +550,17 @@ class DecoderTrainer(nn.Module):
         if only_model:
             return loaded_obj

-        for ind in range(0, self.num_unets):
+        for ind, last_step in zip(range(0, self.num_unets), self.steps.cpu().unbind()):
+
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
+            warmup_scheduler = self.warmup_schedulers[ind]

             self.accelerator.unwrap_model(optimizer).load_state_dict(loaded_obj[optimizer_key])

+            if exists(warmup_scheduler):
+                warmup_scheduler.last_step = last_step
+
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
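Because `pytorch_warmup` keeps its own step counter rather than serializing through the optimizer, resuming needs that counter fast-forwarded, which is what the `last_step` assignment above does with the trainer's per-unet `steps` buffer. In isolation the idea looks like this (the optimizer and the `resume_step` value are placeholders):

```python
import torch
import pytorch_warmup as warmup

# Isolated sketch of fast-forwarding warmup after a checkpoint load; resume_step
# stands in for the per-unet step count stored in the trainer's `steps` buffer.
optimizer = torch.optim.AdamW(torch.nn.Linear(16, 16).parameters(), lr = 1e-4)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 500)

resume_step = 1200                        # e.g. the value restored into steps[ind]
warmup_scheduler.last_step = resume_step  # warmup resumes as if it had already run 1200 steps
```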
@@ -554,12 +593,20 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1

         optimizer = getattr(self, f'optim{index}')
+        scheduler = getattr(self, f'sched{index}')

         if exists(self.max_grad_norm):
             self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients

         optimizer.step()
         optimizer.zero_grad()
+
+        warmup_scheduler = self.warmup_schedulers[index]
+        scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
+
+        with scheduler_context():
+            scheduler.step()
+
         if self.use_ema:
             ema_unet = self.ema_unets[index]
             ema_unet.update()
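The `dampening()` context is the pytorch-warmup idiom: the wrapped LR scheduler steps inside it and the warmup factor is applied on top, while `nullcontext` keeps the code path identical when no warmup was configured for that unet. A condensed sketch of one such training step (the model and loss are toys):

```python
from contextlib import nullcontext

import torch
from torch.optim.lr_scheduler import LambdaLR
import pytorch_warmup as warmup

# Condensed sketch of the per-step update path above, with a toy model and loss.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 100)  # would be None with warmup disabled

for _ in range(10):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    # warmup dampening wraps the scheduler step; nullcontext is the no-op fallback
    scheduler_context = warmup_scheduler.dampening if warmup_scheduler is not None else nullcontext
    with scheduler_context():
        scheduler.step()
```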
@@ -614,7 +661,6 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.

         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            # with autocast(enabled = self.amp):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                 loss = loss * chunk_size_frac
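The stale `torch.cuda.amp` comment goes away because mixed precision is now owned by Accelerate: whether `accelerator.autocast()` actually casts is decided when the `Accelerator` is constructed. A rough sketch of that division of labor, assuming a recent accelerate (the model and loss are placeholders):

```python
import torch
from accelerate import Accelerator

# Rough sketch: the Accelerator decides whether autocast is active, the training
# code just wraps the forward pass. 'no' / 'fp16' / 'bf16' are accelerate's modes.
accelerator = Accelerator(mixed_precision = 'no')
model = accelerator.prepare(torch.nn.Linear(16, 16))

with accelerator.autocast():   # no-op when mixed_precision == 'no'
    loss = model(torch.randn(4, 16)).pow(2).mean()

accelerator.backward(loss)     # handles gradient scaling when fp16 is enabled
```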
@@ -1 +1 @@
-__version__ = '0.16.3'
+__version__ = '0.16.5'