Quality of life improvements for tracker savers (#210)

The default save location is now none so if keys are not specified the
corresponding checkpoint type is not saved.

Models and checkpoints are now both saved with version number and the
config used to create them in order to simplify loading.

Documentation was fixed to be in line with current usage.
This commit is contained in:
Aidan Dempster
2022-07-19 20:50:18 -04:00
committed by GitHub
parent 291377bb9c
commit f97e55ec6b
5 changed files with 53 additions and 45 deletions

View File

@@ -74,9 +74,6 @@ Settings for controlling the training hyperparameters.
| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |
| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
| `ema_beta` | No | `0.99` | The ema coefficient. |
| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |
**<ins>Evaluate</ins>:**
@@ -163,9 +160,10 @@ All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
@@ -177,7 +175,6 @@ If using `huggingface`
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`

View File

@@ -56,9 +56,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,14 +93,15 @@
},
"save": [{
"save_to": "wandb"
"save_to": "wandb",
"save_latest_to": "latest.pth"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_latest_to": "path/to/model_dir/latest.pth",
"save_best_to": "path/to/model_dir/best.pth",
"save_meta_to": "path/to/directory/for/assorted/files",
"save_type": "model"
}]

View File

@@ -61,9 +61,6 @@
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
@@ -96,7 +93,8 @@
},
"save": [{
"save_to": "local"
"save_to": "local",
"save_latest_to": "latest.pth"
}]
}
}

View File

@@ -4,13 +4,15 @@ import json
from pathlib import Path
import shutil
from itertools import zip_longest
from typing import Optional, List, Union
from typing import Any, Optional, List, Union
from pydantic import BaseModel
import torch
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.version import __version__
from packaging import version
# constants
@@ -21,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val):
return val is not None
# load file functions
def load_wandb_file(run_path, file_path, **kwargs):
wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
file_reference = wandb.restore(file_path, run_path=run_path)
return file_reference.name
def load_local_file(file_path, **kwargs):
return file_path
class BaseLogger:
"""
An abstract class representing an object that can log data.
@@ -234,7 +226,7 @@ class LocalLoader(BaseLoader):
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
if not self.file_path.exists() and not self.only_auto_resume:
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
@@ -283,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_latest_to: Optional[Union[str, bool]] = None,
save_best_to: Optional[Union[str, bool]] = None,
save_meta_to: Optional[str] = None,
save_type: str = 'checkpoint',
**kwargs
):
@@ -295,10 +287,10 @@ class BaseSaver:
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.saving_meta = save_meta_to is not None
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -459,6 +451,11 @@ class Tracker:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
self.save_metadata = dict(
version = version.parse(__version__)
) # Data that will be saved alongside the checkpoint or model
self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -507,9 +504,16 @@ class Tracker:
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
if saver.saving_meta:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def add_save_metadata(self, state_dict_key: str, metadata: Any):
"""
Adds a new piece of metadata that will be saved along with the model or decoder.
"""
self.save_metadata[state_dict_key] = metadata
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
Gets the state dict to be saved and writes it to file_path.
@@ -518,24 +522,34 @@ class Tracker:
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
# Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
trainer.save(file_path, overwrite=True, **kwargs, **metadata)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
prior: DiffusionPrior = trainer.unwrap_model(prior)
# Remove CLIP if it is part of the model
prior.clip = None
model_state_dict = prior.state_dict()
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
# Remove CLIP if it is part of the model
decoder.clip = None
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
model_state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
model_state_dict = decoder.state_dict()
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
state_dict = {
**self.save_metadata,
'model': model_state_dict
}
torch.save(state_dict, file_path)
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):

View File

@@ -513,6 +513,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
}
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
return tracker
def initialize_training(config: TrainDecoderConfig, config_path):