give diffusion prior trainer cosine annealing lr too
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -182,6 +182,7 @@ class DiffusionPriorTrainer(nn.Module):
         max_grad_norm = None,
         group_wd_params = True,
         warmup_steps = 1,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )
 
-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
 
         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
 
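
For reference, the scheduler selection above reduces to the following standalone sketch (plain torch; the linear model, base lr, and step count are placeholders, not the trainer's actual configuration):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

model = torch.nn.Linear(8, 8)                 # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

cosine_decay_max_steps = 100                  # placeholder; None keeps the old constant-lr behavior

if cosine_decay_max_steps is not None:
    # half-cosine decay from the base lr toward 0 over T_max scheduler steps
    scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
    # constant multiplier of 1.0 every step, i.e. the previous fixed lr
    scheduler = LambdaLR(optimizer, lr_lambda = lambda _: 1.0)

for _ in range(100):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())                # ~[0.0] after a full cosine period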
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
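
Note the existing FIXME: a LambdaLR object that closes over a lambda can't be pickled directly, which is why the checkpoint stores scheduler.state_dict() rather than the scheduler itself. A minimal sketch of that save path (plain torch with placeholder names, not the trainer's full save_obj):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(8, 8)                 # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

# state_dict() holds only plain Python values (e.g. last_epoch), so it
# serializes cleanly, unlike pickling a LambdaLR that wraps a lambda
save_obj = dict(
    optimizer = optimizer.state_dict(),
    scheduler = scheduler.state_dict(),
)
torch.save(save_obj, './checkpoint.pt')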
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
 
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])
+
         # set warmupstep
         if exists(self.warmup_scheduler):
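
Restoring mirrors the save path: the optimizer state is loaded first, then the scheduler's, so the cosine decay resumes from the saved last_epoch instead of restarting. A sketch continuing from the checkpoint written above (same placeholder setup):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(8, 8)                 # placeholder, same shape as at save time
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)

loaded_obj = torch.load('./checkpoint.pt')
optimizer.load_state_dict(loaded_obj['optimizer'])
scheduler.load_state_dict(loaded_obj['scheduler'])   # picks up at the saved last_epoch

print(scheduler.last_epoch, scheduler.get_last_lr())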
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.0'
+__version__ = '1.8.1'
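
Taken together, the change is opt-in via the new constructor argument. A hedged usage sketch; the surrounding setup (OpenAIClipAdapter, the DiffusionPriorNetwork/DiffusionPrior arguments, lr and wd) follows the repo's README, and all values are illustrative rather than recommended:

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

clip = OpenAIClipAdapter()                    # downloads OpenAI CLIP weights

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
)

trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    warmup_steps = 1000,                      # linear warmup, as before
    cosine_decay_max_steps = 100_000          # new: anneal lr toward 0 over 100k steps
)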