fix decoder trainer optimizer loading (since there are multiple for each unet), also save and load step number correctly

This commit is contained in:
Phil Wang
2022-05-22 15:21:00 -07:00
parent 271a376eaf
commit 49de72040c
2 changed files with 24 additions and 9 deletions

View File

@@ -196,7 +196,7 @@ class EMA(nn.Module):
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('step', torch.tensor([0]))
def restore_ema_model_device(self):
device = self.initted.device
@@ -292,7 +292,7 @@ class DiffusionPriorTrainer(nn.Module):
self.max_grad_norm = max_grad_norm
self.register_buffer('step', torch.tensor([0.]))
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True):
path = Path(path)
@@ -303,7 +303,8 @@ class DiffusionPriorTrainer(nn.Module):
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version()
version = get_pkg_version(),
step = self.step.item()
)
if self.use_ema:
@@ -321,6 +322,7 @@ class DiffusionPriorTrainer(nn.Module):
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
@@ -459,12 +461,18 @@ class DecoderTrainer(nn.Module):
path.parent.mkdir(parents = True, exist_ok = True)
save_obj = dict(
scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(),
model = self.decoder.state_dict(),
version = get_pkg_version()
version = get_pkg_version(),
step = self.step.item()
)
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -480,12 +488,19 @@ class DecoderTrainer(nn.Module):
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
optimizer_key = f'scaler{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])
if self.use_ema:
assert 'ema' in loaded_obj