give diffusion prior trainer cosine annealing lr too

This commit is contained in:
Phil Wang
2022-08-15 07:38:01 -07:00
parent 6b9b4b9e5e
commit 8f38339c2b
2 changed files with 10 additions and 3 deletions

View File

@@ -182,6 +182,7 @@ class DiffusionPriorTrainer(nn.Module):
         max_grad_norm = None,
         group_wd_params = True,
         warmup_steps = 1,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -234,6 +235,9 @@ class DiffusionPriorTrainer(nn.Module):
             **kwargs
         )
-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmupstep
         if exists(self.warmup_scheduler):

View File

@@ -1 +1 @@
-__version__ = '1.8.0'
+__version__ = '1.8.1'