diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index f154512..32ca587 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -182,6 +182,7 @@ class DiffusionPriorTrainer(nn.Module):
         max_grad_norm = None,
         group_wd_params = True,
         warmup_steps = 1,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )
-
-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)

         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None

@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
+
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])

         # set warmupstep
         if exists(self.warmup_scheduler):
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index b280975..e8b6b09 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.0'
+__version__ = '1.8.1'
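
For reference, a minimal usage sketch (not part of the diff) of the new cosine_decay_max_steps argument: when it is provided, the trainer's learning rate decays via CosineAnnealingLR instead of staying constant. The model sizes and hyperparameter values below are illustrative assumptions, not values taken from this change.

# illustrative sketch -- dims, timesteps, lr, wd, warmup_steps and max steps are placeholder assumptions
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, DiffusionPriorTrainer

prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 100,
    condition_on_text_encodings = False
)

trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    warmup_steps = 1000,
    cosine_decay_max_steps = 100_000  # if left as None, the previous constant-LR LambdaLR behavior is kept
)

Because the scheduler's state_dict is now included in save()/load(), the cosine schedule resumes from the correct step after loading a checkpoint.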