diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index a6b6bbc..b20b206 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module): ) if exists(cosine_decay_max_steps): - self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps) + self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps) else: self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)