mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
allow for saving of additional fields on save method in trainers, and return loaded objects from the load method
This commit is contained in:
@@ -1712,7 +1712,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert self.unconditional or (exists(clip) ^ (exists(image_size) or exists(image_sizes))), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
|
||||
@@ -288,7 +288,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
def save(self, path, overwrite = True):
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
@@ -298,7 +298,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.diffusion_prior.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
step = self.step.item()
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.use_ema:
|
||||
@@ -319,7 +320,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return
|
||||
return loaded_obj
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
@@ -328,6 +329,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
def update(self):
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
@@ -449,7 +452,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def save(self, path, overwrite = True):
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
@@ -457,7 +460,8 @@ class DecoderTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
model = self.decoder.state_dict(),
|
||||
version = get_pkg_version(),
|
||||
step = self.step.item()
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
@@ -485,7 +489,7 @@ class DecoderTrainer(nn.Module):
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return
|
||||
return loaded_obj
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
scaler_key = f'scaler{ind}'
|
||||
@@ -500,6 +504,8 @@ class DecoderTrainer(nn.Module):
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@property
|
||||
def unets(self):
|
||||
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||
|
||||
Reference in New Issue
Block a user