diff --git a/configs/README.md b/configs/README.md
index e4fb77d..d473495 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -74,9 +74,6 @@ Settings for controlling the training hyperparameters.
 | `validation_samples` | No | `None` | The number of samples to use for validation. `None` means the entire validation set. |
 | `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
 | `ema_beta` | No | `0.99` | The EMA coefficient. |
-| `save_all` | No | `False` | If True, preserves a checkpoint for every epoch. |
-| `save_latest` | No | `True` | If True, overwrites the `latest.pth` every time the model is saved. |
-| `save_best` | No | `True` | If True, overwrites the `best.pth` every time the model has a lower validation loss than all previous models. |
 | `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the corresponding unet is frozen. A value of `None` trains all unets. |
 
 **Evaluate:**
@@ -163,9 +160,10 @@ All save locations have these configuration options
 | Option | Required | Default | Description |
 | ------ | -------- | ------- | ----------- |
 | `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
-| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
-| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
-| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
+| `save_latest_to` | No | `None` | The relative path to save the latest model to. |
+| `save_best_to` | No | `None` | The relative path to save the best model to whenever the model achieves a lower validation loss than all previous models. |
+| `save_meta_to` | No | `None` | The path to save metadata files under, including the config files used to start the training. |
+| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves the full training state; `model` saves only the model weights (the EMA weights if EMA is enabled). |
 
 If using `local`
 | Option | Required | Default | Description |
@@ -177,7 +175,6 @@ If using `huggingface`
 | ------ | -------- | ------- | ----------- |
 | `save_to` | Yes | N/A | Must be `huggingface`. |
 | `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
-| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
 | `token_path` | No | `None` | If logging in with the huggingface CLI is not possible, point this to a token file instead. |
 
 If using `wandb`
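A quick way to see the effect of the new `None` defaults: with no target path set, a saver now refuses to construct at all. Below is a minimal sketch, assuming only what the `trackers.py` hunks later in this diff show about `BaseSaver` (the class behind the `local`, `huggingface`, and `wandb` options); the paths are illustrative:

```python
from dalle2_pytorch.trackers import BaseSaver

# Under the new defaults every target is None, so a saver with no targets
# fails fast with 'At least one saving option must be specified'.
try:
    BaseSaver(data_path='./.tracker-data')
except AssertionError as err:
    print(err)

# Setting any one target is enough; the flags mirror the new attributes.
saver = BaseSaver(data_path='./.tracker-data', save_meta_to='meta/')
assert saver.saving_meta and not saver.saving_latest and not saver.saving_best
```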
diff --git a/configs/train_decoder_config.example.json b/configs/train_decoder_config.example.json
index cebdb02..22658cc 100644
--- a/configs/train_decoder_config.example.json
+++ b/configs/train_decoder_config.example.json
@@ -56,9 +56,6 @@
         "use_ema": true,
         "ema_beta": 0.99,
         "amp": false,
-        "save_all": false,
-        "save_latest": true,
-        "save_best": true,
         "unet_training_mask": [true]
     },
     "evaluate": {
@@ -96,14 +93,15 @@
     },
     "save": [{
-        "save_to": "wandb"
+        "save_to": "wandb",
+        "save_latest_to": "latest.pth"
     },
     {
         "save_to": "huggingface",
         "huggingface_repo": "Veldrovive/test_model",
-        "save_all": true,
-        "save_latest": true,
-        "save_best": true,
+        "save_latest_to": "path/to/model_dir/latest.pth",
+        "save_best_to": "path/to/model_dir/best.pth",
+        "save_meta_to": "path/to/directory/for/assorted/files",
         "save_type": "model"
     }]
diff --git a/configs/train_decoder_config.test.json b/configs/train_decoder_config.test.json
index 26e1c43..101846e 100644
--- a/configs/train_decoder_config.test.json
+++ b/configs/train_decoder_config.test.json
@@ -61,9 +61,6 @@
         "use_ema": true,
         "ema_beta": 0.99,
         "amp": false,
-        "save_all": false,
-        "save_latest": true,
-        "save_best": true,
         "unet_training_mask": [true]
     },
     "evaluate": {
@@ -96,7 +93,8 @@
     },
     "save": [{
-        "save_to": "local"
+        "save_to": "local",
+        "save_latest_to": "latest.pth"
     }]
   }
 }
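To make the `"save_type": "model"` entries above concrete: after the `trackers.py` changes below, a model save is no longer a bare state dict but nests the weights under a `model` key with save metadata alongside. A hedged sketch of reading one back (the path is taken from the example config; the commented load assumes you have already rebuilt the decoder):

```python
import torch

# Model saves now nest the weights under 'model'; metadata such as 'version'
# (and 'config', attached in train_decoder.py at the end of this diff) sits alongside.
state = torch.load('path/to/model_dir/latest.pth', map_location='cpu')
print(state['version'])            # dalle2_pytorch version recorded at save time
model_state_dict = state['model']  # plain weights, with decoder.clip stripped out
# decoder.load_state_dict(model_state_dict)  # assumes a freshly constructed decoder
```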
diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py
index 2d0ba08..057fbb9 100644
--- a/dalle2_pytorch/trackers.py
+++ b/dalle2_pytorch/trackers.py
@@ -4,13 +4,15 @@
 import json
 from pathlib import Path
 import shutil
 from itertools import zip_longest
-from typing import Optional, List, Union
+from typing import Any, Optional, List, Union
 from pydantic import BaseModel
 import torch
-
+from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.utils import import_or_print_error
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
+from dalle2_pytorch.version import __version__
+from packaging import version
 
 # constants
 
@@ -21,16 +23,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
 
 def exists(val):
     return val is not None
 
-# load file functions
-
-def load_wandb_file(run_path, file_path, **kwargs):
-    wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
-    file_reference = wandb.restore(file_path, run_path=run_path)
-    return file_reference.name
-
-def load_local_file(file_path, **kwargs):
-    return file_path
-
 class BaseLogger:
     """
     An abstract class representing an object that can log data.
@@ -234,7 +226,7 @@ class LocalLoader(BaseLoader):
 
     def init(self, logger: BaseLogger, **kwargs) -> None:
         # Makes sure the file exists to be loaded
-        if not self.file_path.exists():
+        if not self.file_path.exists() and not self.only_auto_resume:
             raise FileNotFoundError(f'Model not found at {self.file_path}')
 
     def recall(self) -> dict:
@@ -283,9 +275,9 @@ def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
 class BaseSaver:
     def __init__(self,
         data_path: str,
-        save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
-        save_best_to: Optional[Union[str, bool]] = 'best.pth',
-        save_meta_to: str = './',
+        save_latest_to: Optional[Union[str, bool]] = None,
+        save_best_to: Optional[Union[str, bool]] = None,
+        save_meta_to: Optional[str] = None,
         save_type: str = 'checkpoint',
         **kwargs
     ):
@@ -295,10 +287,10 @@ class BaseSaver:
         self.save_best_to = save_best_to
         self.saving_best = save_best_to is not None and save_best_to is not False
         self.save_meta_to = save_meta_to
+        self.saving_meta = save_meta_to is not None
         self.save_type = save_type
         assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
-        assert self.save_meta_to is not None, '`save_meta_to` must be provided'
-        assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
+        assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
 
     def init(self, logger: BaseLogger, **kwargs) -> None:
         raise NotImplementedError
@@ -459,6 +451,11 @@ class Tracker:
             print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
             print(f"New logger config: {self.logger.__dict__}")
 
+        self.save_metadata = dict(
+            version = version.parse(__version__)
+        )  # Data that will be saved alongside the checkpoint or model
+        self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps']  # These keys would cause us to error if we try to save them as metadata
+
         assert self.logger is not None, '`logger` must be set before `init` is called'
         if self.dummy_mode:
             # The only thing we need is a loader
@@ -507,8 +504,15 @@ class Tracker:
         # Save the config under config_name in the root folder of data_path
         shutil.copy(current_config_path, self.data_path / config_name)
         for saver in self.savers:
-            remote_path = Path(saver.save_meta_to) / config_name
-            saver.save_file(current_config_path, str(remote_path))
+            if saver.saving_meta:
+                remote_path = Path(saver.save_meta_to) / config_name
+                saver.save_file(current_config_path, str(remote_path))
+
+    def add_save_metadata(self, state_dict_key: str, metadata: Any):
+        """
+        Adds a new piece of metadata that will be saved along with the model or decoder.
+        """
+        self.save_metadata[state_dict_key] = metadata
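The method just added is the public hook for the `save_metadata` dict initialized in `__init__` above: anything registered through it rides along with every subsequent save, and `blacklisted_checkpoint_metadata_keys` filters out colliding keys for checkpoint saves (see the next hunk). A short sketch, where `tracker` is an initialized `Tracker` and the key and value are illustrative:

```python
# Attach run-level metadata once; it is written into every subsequent save.
tracker.add_save_metadata(state_dict_key='notes', metadata='warmup run with frozen unet 0')
# For save_type 'checkpoint', keys like 'model' or 'optimizer' would collide with
# the trainer state, so they are dropped via blacklisted_checkpoint_metadata_keys.
```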
+ """ + self.save_metadata[state_dict_key] = metadata def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path: """ @@ -518,24 +522,34 @@ class Tracker: """ assert save_type in ['checkpoint', 'model'] if save_type == 'checkpoint': - trainer.save(file_path, overwrite=True, **kwargs) + # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict + metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys} + trainer.save(file_path, overwrite=True, **kwargs, **metadata) elif save_type == 'model': if isinstance(trainer, DiffusionPriorTrainer): prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior - state_dict = trainer.unwrap_model(prior).state_dict() - torch.save(state_dict, file_path) + prior: DiffusionPrior = trainer.unwrap_model(prior) + # Remove CLIP if it is part of the model + prior.clip = None + model_state_dict = prior.state_dict() elif isinstance(trainer, DecoderTrainer): - decoder = trainer.accelerator.unwrap_model(trainer.decoder) + decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder) + # Remove CLIP if it is part of the model + decoder.clip = None if trainer.use_ema: trainable_unets = decoder.unets decoder.unets = trainer.unets # Swap EMA unets in - state_dict = decoder.state_dict() + model_state_dict = decoder.state_dict() decoder.unets = trainable_unets # Swap back else: - state_dict = decoder.state_dict() - torch.save(state_dict, file_path) + model_state_dict = decoder.state_dict() else: raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?') + state_dict = { + **self.save_metadata, + 'model': model_state_dict + } + torch.save(state_dict, file_path) return Path(file_path) def save(self, trainer, is_best: bool, is_latest: bool, **kwargs): diff --git a/train_decoder.py b/train_decoder.py index 22ff816..3c41df7 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -513,6 +513,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_ } tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy) tracker.save_config(config_path, config_name='decoder_config.json') + tracker.add_save_metadata(state_dict_key='config', metadata=config.dict()) return tracker def initialize_training(config: TrainDecoderConfig, config_path):