allow for saving of additional fields on save method in trainers, and return loaded objects from the load method

commit ae42d03006
parent 4d346e98d9
Author: Phil Wang
Date:   2022-05-22 22:14:25 -07:00

3 changed files with 14 additions and 8 deletions


@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
 
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
         if exists(clip):
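Aside from the trainer changes, this hunk relaxes the Decoder constructor check from an exclusive-or to a plain or, so supplying both CLIP and an explicit image size no longer trips the assertion. A minimal standalone sketch of the behavioral difference (illustrative variables, not the actual constructor):

    clip_given, size_given = True, True
    # old check: exclusive-or fails when both CLIP and an image size are supplied
    # assert clip_given ^ size_given   # would raise AssertionError
    # new check: plain or only requires that at least one be supplied
    assert clip_given or size_given    # passes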


@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
 
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )
 
         if self.use_ema:
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
assert 'ema' in loaded_obj assert 'ema' in loaded_obj
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict) self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
return loaded_obj
def update(self): def update(self):
if exists(self.max_grad_norm): if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer) self.scaler.unscale_(self.optimizer)
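Taken together, save now forwards arbitrary keyword arguments into the checkpoint dict, and load hands the whole dict back to the caller. A hypothetical usage sketch (the trainer instance, path, and extra fields here are illustrative, not from this commit):

    # stash extra metadata alongside the model/optimizer state
    trainer.save('./prior.pt', overwrite = True, epoch = 5, best_valid_loss = 0.123)

    # read the extras back out of the returned checkpoint dict
    loaded_obj = trainer.load('./prior.pt')
    print(loaded_obj['epoch'], loaded_obj['best_valid_loss'])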
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0.]))
 
-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
 
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.decoder.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )
 
         for ind in range(0, self.num_unets):
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
         if only_model:
-            return
+            return loaded_obj
 
         for ind in range(0, self.num_unets):
             scaler_key = f'scaler{ind}'
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
 
+        return loaded_obj
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
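DecoderTrainer gets the same treatment, and since load now returns the checkpoint dict even on the only_model early exit, extra fields stay accessible when only the weights are restored. Another hypothetical sketch:

    loaded_obj = decoder_trainer.load('./decoder.pt', only_model = True)
    # optimizer/scaler state is skipped, but saved extras are still returned
    start_epoch = loaded_obj.get('epoch', 0)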


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.9',
+  version = '0.4.10',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',