Mirror of https://github.com/lucidrains/DALLE2-pytorch.git

Compare commits: 1 commit

Commit 27f19ba7fa
@@ -181,7 +181,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
         cosine_decay_max_steps = None,
         **kwargs
     ):
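The hunk above makes learning-rate warmup opt-in: with the new warmup_steps = None default, presumably no warmup scheduler is constructed at all, where the old default of 1 built a vacuous one-step warmup. A minimal sketch of that construction logic, assuming the trainer builds a pytorch_warmup LinearWarmup; the exists helper and the scheduler choice are assumptions, not shown in the hunk:

import pytorch_warmup as warmup
import torch

def exists(val):
    # small helper in the style of the repository
    return val is not None

warmup_steps = None  # the new default: warmup is skipped entirely

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr = 3e-4)

# only build a warmup scheduler when warmup_steps is given
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None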
@@ -357,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
 
         # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
         if not self.accelerator.optimizer_step_was_skipped:
-            with self.warmup_scheduler.dampening():
+            sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+            with sched_context():
                 self.scheduler.step()
 
         if self.use_ema:
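On the stepping side, the patch swaps the unconditional with self.warmup_scheduler.dampening(): for a context-manager factory chosen up front, falling back to contextlib.nullcontext when no warmup scheduler exists, so the with-block itself stays unconditional. A self-contained sketch of the same guard; the optimizer and scheduler objects here are stand-ins, not the trainer's own:

from contextlib import nullcontext

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

def exists(val):
    return val is not None

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr = 3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max = 1000)
warmup_scheduler = None  # e.g. a pytorch_warmup.LinearWarmup, or None when warmup is off

# dampening() is pytorch_warmup's context manager; nullcontext is a no-op stand-in
sched_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext

optimizer.step()  # step the optimizer first, as LR schedulers expect
with sched_context():
    scheduler.step()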
@@ -1 +1 @@
-__version__ = '1.8.1'
+__version__ = '1.8.2'