mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Quality of life improvements for tracker savers (#210)
The default save location is now none so if keys are not specified the corresponding checkpoint type is not saved. Models and checkpoints are now both saved with version number and the config used to create them in order to simplify loading. Documentation was fixed to be in line with current usage.
This commit is contained in:
@@ -4,13 +4,15 @@ import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, List, Union
|
||||
from typing import Any, Optional, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import torch
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
# constants
|
||||
|
||||
@@ -21,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# load file functions
|
||||
|
||||
def load_wandb_file(run_path, file_path, **kwargs):
|
||||
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
file_reference = wandb.restore(file_path, run_path=run_path)
|
||||
return file_reference.name
|
||||
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
class BaseLogger:
|
||||
"""
|
||||
An abstract class representing an object that can log data.
|
||||
@@ -234,7 +226,7 @@ class LocalLoader(BaseLoader):
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be loaded
|
||||
if not self.file_path.exists():
|
||||
if not self.file_path.exists() and not self.only_auto_resume:
|
||||
raise FileNotFoundError(f'Model not found at {self.file_path}')
|
||||
|
||||
def recall(self) -> dict:
|
||||
@@ -283,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
|
||||
class BaseSaver:
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
|
||||
save_best_to: Optional[Union[str, bool]] = 'best.pth',
|
||||
save_meta_to: str = './',
|
||||
save_latest_to: Optional[Union[str, bool]] = None,
|
||||
save_best_to: Optional[Union[str, bool]] = None,
|
||||
save_meta_to: Optional[str] = None,
|
||||
save_type: str = 'checkpoint',
|
||||
**kwargs
|
||||
):
|
||||
@@ -295,10 +287,10 @@ class BaseSaver:
|
||||
self.save_best_to = save_best_to
|
||||
self.saving_best = save_best_to is not None and save_best_to is not False
|
||||
self.save_meta_to = save_meta_to
|
||||
self.saving_meta = save_meta_to is not None
|
||||
self.save_type = save_type
|
||||
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
|
||||
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -459,6 +451,11 @@ class Tracker:
|
||||
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
|
||||
print(f"New logger config: {self.logger.__dict__}")
|
||||
|
||||
self.save_metadata = dict(
|
||||
version = version.parse(__version__)
|
||||
) # Data that will be saved alongside the checkpoint or model
|
||||
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
|
||||
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
@@ -507,8 +504,15 @@ class Tracker:
|
||||
# Save the config under config_name in the root folder of data_path
|
||||
shutil.copy(current_config_path, self.data_path / config_name)
|
||||
for saver in self.savers:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
if saver.saving_meta:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
|
||||
def add_save_metadata(self, state_dict_key: str, metadata: Any):
|
||||
"""
|
||||
Adds a new piece of metadata that will be saved along with the model or decoder.
|
||||
"""
|
||||
self.save_metadata[state_dict_key] = metadata
|
||||
|
||||
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
|
||||
"""
|
||||
@@ -518,24 +522,34 @@ class Tracker:
|
||||
"""
|
||||
assert save_type in ['checkpoint', 'model']
|
||||
if save_type == 'checkpoint':
|
||||
trainer.save(file_path, overwrite=True, **kwargs)
|
||||
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
|
||||
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
|
||||
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
prior: DiffusionPrior = trainer.unwrap_model(prior)
|
||||
# Remove CLIP if it is part of the model
|
||||
prior.clip = None
|
||||
model_state_dict = prior.state_dict()
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
# Remove CLIP if it is part of the model
|
||||
decoder.clip = None
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
decoder.unets = trainer.unets # Swap EMA unets in
|
||||
state_dict = decoder.state_dict()
|
||||
model_state_dict = decoder.state_dict()
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
state_dict = decoder.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
model_state_dict = decoder.state_dict()
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
state_dict = {
|
||||
**self.save_metadata,
|
||||
'model': model_state_dict
|
||||
}
|
||||
torch.save(state_dict, file_path)
|
||||
return Path(file_path)
|
||||
|
||||
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user