mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Overhauled the tracker system (#172)
* Overhauled the tracker system Separated the logging and saving capabilities Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script Added class separation between different types of loaders and savers to make the system more verbose * Changed the saver system to only save the checkpoint once * Added better error handling for saving checkpoints * Fixed an error where wandb would error when passed arbitrary kwargs * Fixed variable naming issues for improved saver Added more logging during long pauses * Fixed which methods need to be dummy to immediatly return Added the ability to set whether you find unused parameters * Added more logging for when a wandb loader fails
This commit is contained in:
@@ -505,12 +505,7 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
self.accelerator.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location = 'cpu')
|
||||
|
||||
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
|
||||
if version.parse(__version__) != version.parse(loaded_obj['version']):
|
||||
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||
|
||||
@@ -530,6 +525,14 @@ class DecoderTrainer(nn.Module):
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location = 'cpu')
|
||||
|
||||
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
|
||||
|
||||
return loaded_obj
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user