Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
make sure diffusion prior trainer can operate with no warmup
@@ -181,7 +181,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
         cosine_decay_max_steps = None,
         **kwargs
     ):
@@ -357,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
 
         # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
         if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                 self.scheduler.step()
 
         if self.use_ema:
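Below is a minimal, runnable sketch (not the repository's exact code) of the pattern this commit introduces: with warmup_steps defaulting to None, no warmup scheduler is constructed, and the learning-rate scheduler is stepped under contextlib.nullcontext instead of pytorch-warmup's dampening() context. The model, optimizer, and CosineAnnealingLR here are illustrative stand-ins for the trainer's internals; exists() is the usual is-not-None helper used throughout the repo.

import torch
from contextlib import nullcontext
import pytorch_warmup

def exists(val):
    return val is not None

# illustrative stand-ins for the trainer's internals (assumed, not from the commit)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 1000)

warmup_steps = None  # the "no warmup" configuration this commit enables

# only build a warmup scheduler when warmup is actually requested
warmup_scheduler = pytorch_warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None

for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()

    # the fix: dampen the learning rate only when a warmup scheduler exists,
    # otherwise step the decay scheduler under a no-op context
    sched_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
    with sched_context():
        scheduler.step()

The point of the change: the old default warmup_steps = 1 always built a warmup scheduler, so self.warmup_scheduler.dampening() was called unconditionally; with warmup_steps = None the trainer can run pure cosine decay (or no schedule at all) without attempting to dampen through a nonexistent warmup scheduler.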