Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19)
allow for saving of additional fields in the trainers' save method, and return the loaded object from the load method
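In practice, the change lets a caller stash arbitrary extra metadata in a checkpoint and get the full checkpoint dict back when loading. A minimal usage sketch, assuming an already-constructed trainer; the keyword names epoch and best_loss are illustrative, not part of the library:

    # extra keyword arguments are forwarded into the saved checkpoint dict
    trainer.save('./prior.pt', epoch = 10, best_loss = 0.123)

    # load now returns the deserialized checkpoint, so custom fields
    # (and built-ins such as 'version' and 'step') can be read back
    loaded_obj = trainer.load('./prior.pt')
    print(loaded_obj['epoch'], loaded_obj['best_loss'])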
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
         self.unconditional = unconditional
         assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'

-        assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'

         self.clip = None
         if exists(clip):
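The assertion change above swaps an exclusive-or for a plain or: previously, supplying both clip and an explicit image_size tripped the assert, whereas now any one of the three inputs satisfies it. A standalone illustration, using the same exists helper the repository defines:

    def exists(val):
        return val is not None

    clip, image_size, image_sizes = object(), 256, None

    # old check: clip must be given exclusively of any image size
    old_ok = exists(clip) ^ (exists(image_size) or exists(image_sizes))  # False here

    # new check: at least one of the three must be given
    new_ok = exists(clip) or exists(image_size) or exists(image_sizes)   # True here

    assert not old_ok and new_ok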
@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):

         self.register_buffer('step', torch.tensor([0]))

-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
             optimizer = self.optimizer.state_dict(),
             model = self.diffusion_prior.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )

         if self.use_ema:
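Because save_obj is built with dict(...), the forwarded **kwargs simply become extra top-level keys in the checkpoint. A toy demonstration of the merge (field values are placeholders):

    def build_save_obj(**kwargs):
        # mirrors the pattern in save(): fixed fields first, caller fields merged in
        return dict(
            version = '0.0.0',  # stand-in for get_pkg_version()
            step = 100,
            **kwargs
        )

    obj = build_save_obj(epoch = 3, notes = 'warmup run')
    assert obj['step'] == 100 and obj['epoch'] == 3

One caveat of this pattern: a caller-supplied key that collides with a built-in field (for example step) raises a TypeError, since dict() rejects duplicate keyword arguments.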
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

         if only_model:
-            return
+            return loaded_obj

         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
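The only_model early exit now also returns the checkpoint, so a weights-only load no longer discards the metadata. A hedged sketch, assuming a constructed prior_trainer:

    # load only the model weights, but still inspect the stored metadata
    loaded_obj = prior_trainer.load('./prior.pt', only_model = True)
    print('checkpoint was written at step', loaded_obj['step'])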
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)

+        return loaded_obj
+
     def update(self):
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
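Since the checkpoint records the package version at save time, the returned dict also enables a cheap compatibility check when resuming. A sketch; reading the installed version via importlib.metadata is an assumption here, not something the library does for you:

    from importlib.metadata import version  # stdlib, Python 3.8+

    loaded_obj = prior_trainer.load('./prior.pt')

    # 'version' is written by save() via get_pkg_version()
    if loaded_obj.get('version') != version('dalle2-pytorch'):
        print('warning: checkpoint was written by a different dalle2-pytorch version')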
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):

         self.register_buffer('step', torch.tensor([0.]))

-    def save(self, path, overwrite = True):
+    def save(self, path, overwrite = True, **kwargs):
         path = Path(path)
         assert not (path.exists() and not overwrite)
         path.parent.mkdir(parents = True, exist_ok = True)
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
         save_obj = dict(
             model = self.decoder.state_dict(),
             version = get_pkg_version(),
-            step = self.step.item()
+            step = self.step.item(),
+            **kwargs
         )

         for ind in range(0, self.num_unets):
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

         if only_model:
-            return
+            return loaded_obj

         for ind in range(0, self.num_unets):
             scaler_key = f'scaler{ind}'
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
             assert 'ema' in loaded_obj
             self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

+        return loaded_obj
+
     @property
     def unets(self):
         return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
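The DecoderTrainer changes are symmetric; the one wrinkle is that its checkpoint also carries one gradient-scaler entry per unet under indexed keys (scaler0, scaler1, ...), which the returned dict now exposes. A round-trip sketch, assuming a constructed decoder_trainer and with sample_timesteps as an illustrative extra field:

    decoder_trainer.save('./decoder.pt', sample_timesteps = 250)

    loaded_obj = decoder_trainer.load('./decoder.pt')
    print(loaded_obj['sample_timesteps'])

    # per-unet scaler states live under indexed keys, per the loop in load()
    scaler_keys = sorted(k for k in loaded_obj if k.startswith('scaler'))
    print(scaler_keys)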