mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-17 10:14:45 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27f19ba7fa | ||
|
|
8f38339c2b |
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
eps = 1e-6,
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
group_wd_params = True,
|
group_wd_params = True,
|
||||||
warmup_steps = 1,
|
warmup_steps = None,
|
||||||
|
cosine_decay_max_steps = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
**self.optim_kwargs,
|
**self.optim_kwargs,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
if exists(cosine_decay_max_steps):
|
||||||
|
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
|
||||||
|
else:
|
||||||
|
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||||
|
|
||||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||||
|
|
||||||
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||||
save_obj = dict(
|
save_obj = dict(
|
||||||
optimizer = self.optimizer.state_dict(),
|
optimizer = self.optimizer.state_dict(),
|
||||||
|
scheduler = self.scheduler.state_dict(),
|
||||||
warmup_scheduler = self.warmup_scheduler,
|
warmup_scheduler = self.warmup_scheduler,
|
||||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||||
version = version.parse(__version__),
|
version = version.parse(__version__),
|
||||||
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
# unwrap the model when loading from checkpoint
|
# unwrap the model when loading from checkpoint
|
||||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||||
|
|
||||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||||
|
self.scheduler.load_state_dict(loaded_obj['scheduler'])
|
||||||
|
|
||||||
# set warmupstep
|
# set warmupstep
|
||||||
if exists(self.warmup_scheduler):
|
if exists(self.warmup_scheduler):
|
||||||
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
|
|
||||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||||
if not self.accelerator.optimizer_step_was_skipped:
|
if not self.accelerator.optimizer_step_was_skipped:
|
||||||
with self.warmup_scheduler.dampening():
|
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
|
||||||
|
with sched_context():
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.8.0'
|
__version__ = '1.8.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user