allow for saving of additional fields on save method in trainers, and return loaded objects from the load method

This commit is contained in:
Phil Wang
2022-05-22 22:14:25 -07:00
parent 4d346e98d9
commit ae42d03006
3 changed files with 14 additions and 8 deletions

View File

@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0]))
def save(self, path, overwrite = True):
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(),
version = get_pkg_version(),
step = self.step.item()
step = self.step.item(),
**kwargs
)
if self.use_ema:
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
return loaded_obj
self.scaler.load_state_dict(loaded_obj['scaler'])
self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
self.register_buffer('step', torch.tensor([0.]))
def save(self, path, overwrite = True):
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
save_obj = dict(
model = self.decoder.state_dict(),
version = get_pkg_version(),
step = self.step.item()
step = self.step.item(),
**kwargs
)
for ind in range(0, self.num_unets):
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
if only_model:
return
return loaded_obj
for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])