Overhauled the tracker system (#172)

* Overhauled the tracker system
Separated the logging and saving capabilities
Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script
Added class separation between different types of loaders and savers to make the system more verbose

* Changed the saver system to only save the checkpoint once

* Added better error handling for saving checkpoints

* Fixed an error where wandb would error when passed arbitrary kwargs

* Fixed variable naming issues for improved saver
Added more logging during long pauses

* Fixed which methods need to be dummy to immediatly return
Added the ability to set whether you find unused parameters

* Added more logging for when a wandb loader fails
This commit is contained in:
Aidan Dempster
2022-07-01 12:39:40 -04:00
committed by GitHub
parent 7b0edf9e42
commit 27b0f7ca0d
7 changed files with 662 additions and 212 deletions

View File

@@ -91,21 +91,83 @@ Each metric can be enabled by setting its configuration. The configuration keys
**<ins>Tracker</ins>:**
Selects which tracker to use and configures it.
Selects how the experiment will be tracked.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `tracker_type` | No | `console` | Which tracker to use. Currently accepts `console` or `wandb`. |
| `data_path` | No | `./models` | Where the tracker will store local data. |
| `verbose` | No | `False` | Enables console logging for non-console trackers. |
| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
| `log` | Yes | N/A | Logging configuration. |
| `load` | No | `None` | Checkpoint loading configuration. |
| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
Tracking is split up into three sections:
* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
Other configuration options are required for the specific trackers. To see which are required, reference the initializer parameters of each [tracker](../dalle2_pytorch/trackers.py).
**Logging:**
**<ins>Load</ins>:**
Selects where to load a pretrained model from.
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `source` | No | `None` | Supports `file` or `wandb`. |
| `resume` | No | `False` | If the tracker support resuming the run, resume it. |
| `log_type` | Yes | N/A | Must be `console`. |
Other configuration options are required for loading from a specific source. To see which are required, reference the load methods at the top of the [tracker file](../dalle2_pytorch/trackers.py).
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | Must be `wandb`. |
| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |
**Loading:**
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `local`. |
| `file_path` | Yes | N/A | The path to the checkpoint file. |
If using `url`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `url`. |
| `url` | Yes | N/A | The url of the checkpoint file. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
**Saving:**
Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
All save locations have these configuration options
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
| `save_latest_to` | No | `latest.pth` | Sets the relative path to save the latest model to. |
| `save_best_to` | No | `best.pth` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
| `save_type` | No | `'checkpoint'` | The type of save. `'checkpoint'` saves a checkpoint, `'model'` saves a model without any fluff (Saves with ema if ema is enabled). |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `local`. |
If using `huggingface`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `huggingface`. |
| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
| `huggingface_base_path` | Yes | N/A | The base path that checkpoints will be saved under. |
| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
If using `wandb`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `save_to` | Yes | N/A | Must be `wandb`. |
| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |

View File

@@ -80,20 +80,32 @@
}
},
"tracker": {
"tracker_type": "console",
"data_path": "./models",
"overwrite_data_path": true,
"wandb_entity": "",
"wandb_project": "",
"log": {
"log_type": "wandb",
"verbose": false
},
"load": {
"source": null,
"wandb_entity": "your_wandb",
"wandb_project": "your_project",
"run_path": "",
"file_path": "",
"verbose": true
},
"resume": false
"load": {
"load_from": null
},
"save": [{
"save_to": "wandb"
}, {
"save_to": "huggingface",
"huggingface_repo": "Veldrovive/test_model",
"save_all": true,
"save_latest": true,
"save_best": true,
"save_type": "model"
}]
}
}

View File

@@ -1,12 +1,15 @@
import urllib.request
import os
from pathlib import Path
import importlib
import shutil
from itertools import zip_longest
from typing import Optional, List, Union
from pydantic import BaseModel
import torch
from torch import nn
from dalle2_pytorch.utils import import_or_print_error
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
# constants
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
def load_local_file(file_path, **kwargs):
return file_path
# base class
class BaseTracker(nn.Module):
def __init__(self, data_path = DEFAULT_DATA_PATH):
super().__init__()
class BaseLogger:
"""
An abstract class representing an object that can log data.
Parameters:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether of not to always print logs to the console.
"""
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.data_path.mkdir(parents = True, exist_ok = True)
self.verbose = verbose
def init(self, config, **kwargs):
raise NotImplementedError
def log(self, log, **kwargs):
raise NotImplementedError
def log_images(self, images, **kwargs):
raise NotImplementedError
def save_state_dict(self, state_dict, relative_path, **kwargs):
raise NotImplementedError
def recall_state_dict(self, recall_source, *args, **kwargs):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Loads a state dict from any source.
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
this should not be linked to any individual tracker.
Initializes the logger.
Errors if the logger is invalid.
"""
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
if recall_source == 'wandb':
return torch.load(load_wandb_file(*args, **kwargs))
elif recall_source == 'local':
return torch.load(load_local_file(*args, **kwargs))
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def save_file(self, file_path, **kwargs):
raise NotImplementedError
def recall_file(self, recall_source, *args, **kwargs):
if recall_source == 'wandb':
return load_wandb_file(*args, **kwargs)
elif recall_source == 'local':
return load_local_file(*args, **kwargs)
else:
raise ValueError('`recall_source` must be one of `wandb` or `local`')
def log(self, log, **kwargs) -> None:
raise NotImplementedError
# Tracker that no-ops all calls except for recall
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
raise NotImplementedError
class DummyTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def log_file(self, file_path, **kwargs) -> None:
raise NotImplementedError
def init(self, config, **kwargs):
pass
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def log(self, log, **kwargs):
pass
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
def log_images(self, images, **kwargs):
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
pass
def save_file(self, file_path, **kwargs):
pass
# basic stdout class
class ConsoleTracker(BaseTracker):
def init(self, **config):
print(config)
def log(self, log, **kwargs):
def log(self, log, **kwargs) -> None:
print(log)
def log_images(self, images, **kwargs): # noop for logging images
pass
def save_state_dict(self, state_dict, relative_path, **kwargs):
torch.save(state_dict, str(self.data_path / relative_path))
def save_file(self, file_path, **kwargs):
# This is a no-op for local file systems since it is already saved locally
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
pass
# basic wandb class
def log_file(self, file_path, **kwargs) -> None:
pass
class WandbTracker(BaseTracker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
Parameters:
data_path (str): A file path for storing temporary data.
wandb_entity (str): The wandb entity to log to.
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
wandb_resume (bool): Whether to resume a wandb run.
"""
def __init__(self,
data_path: str,
wandb_entity: str,
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_resume: bool = False,
**kwargs
):
super().__init__(data_path, **kwargs)
self.entity = wandb_entity
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
self.resume = wandb_resume
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
assert self.project is not None, "wandb_project must be specified for wandb logger"
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Initializes the wandb run
init_object = {
"entity": self.entity,
"project": self.project,
"config": {**full_config.dict(), **extra_config}
}
if self.run_name is not None:
init_object['name'] = self.run_name
if self.resume:
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
if self.run_name is not None:
print("You are renaming a run. I hope that is what you intended.")
init_object['resume'] = 'must'
init_object['id'] = self.run_id
def init(self, **config):
self.wandb.init(**config)
self.wandb.init(**init_object)
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
def log(self, log, verbose=False, **kwargs):
if verbose:
def log(self, log, **kwargs) -> None:
if self.verbose:
print(log)
self.wandb.log(log, **kwargs)
def log_images(self, images, captions=[], image_section="images", **kwargs):
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
"""
Takes a tensor of images and a list of captions and logs them to wandb.
"""
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
self.log({ image_section: wandb_images }, **kwargs)
def save_state_dict(self, state_dict, relative_path, **kwargs):
"""
Saves a state_dict to disk and uploads it
"""
full_path = str(self.data_path / relative_path)
torch.save(state_dict, full_path)
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
self.wandb.log({ image_section: wandb_images }, **kwargs)
def save_file(self, file_path, base_path=None, **kwargs):
"""
Uploads a file from disk to wandb
"""
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
if base_path is None:
base_path = self.data_path
# Then we take the basepath as the parent of the file_path
base_path = Path(file_path).parent
self.wandb.save(str(file_path), base_path = str(base_path))
def log_error(self, error_string, step=None, **kwargs) -> None:
if self.verbose:
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
}
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
if logger_type == 'custom':
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
try:
logger_class = logger_type_map[logger_type]
except KeyError:
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
return logger_class(data_path, **kwargs)
class BaseLoader:
"""
An abstract class representing an object that can load a model checkpoint.
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, **kwargs):
self.data_path = Path(data_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def recall() -> dict:
raise NotImplementedError
class UrlLoader(BaseLoader):
"""
A loader that downloads the file from a url and loads it
Parameters:
data_path (str): A file path for storing temporary data.
url (str): The url to download the file from.
"""
def __init__(self, data_path: str, url: str, **kwargs):
super().__init__(data_path, **kwargs)
self.url = url
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be downloaded
pass # TODO: Actually implement that
def recall(self) -> dict:
# Download the file
save_path = self.data_path / 'loaded_checkpoint.pth'
urllib.request.urlretrieve(self.url, str(save_path))
# Load the file
return torch.load(str(save_path), map_location='cpu')
class LocalLoader(BaseLoader):
"""
A loader that loads a file from a local path
Parameters:
data_path (str): A file path for storing temporary data.
file_path (str): The path to the file to load.
"""
def __init__(self, data_path: str, file_path: str, **kwargs):
super().__init__(data_path, **kwargs)
self.file_path = Path(file_path)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the file exists to be loaded
if not self.file_path.exists():
raise FileNotFoundError(f'Model not found at {self.file_path}')
def recall(self) -> dict:
# Load the file
return torch.load(str(self.file_path), map_location='cpu')
class WandbLoader(BaseLoader):
"""
A loader that loads a model from an existing wandb run
"""
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
self.file_path = wandb_file_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
# Make sure the file can be downloaded
if self.wandb.run is not None and self.run_path is None:
self.run_path = self.wandb.run.path
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
os.environ["WANDB_SILENT"] = "true"
pass # TODO: Actually implement that
def recall(self) -> dict:
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
return torch.load(file_reference.name, map_location='cpu')
loader_type_map = {
'url': UrlLoader,
'local': LocalLoader,
'wandb': WandbLoader,
}
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
if loader_type == 'custom':
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
try:
loader_class = loader_type_map[loader_type]
except KeyError:
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
return loader_class(data_path, **kwargs)
class BaseSaver:
def __init__(self,
data_path: str,
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
save_best_to: Optional[Union[str, bool]] = 'best.pth',
save_meta_to: str = './',
save_type: str = 'checkpoint',
**kwargs
):
self.data_path = Path(data_path)
self.save_latest_to = save_latest_to
self.saving_latest = save_latest_to is not None and save_latest_to is not False
self.save_best_to = save_best_to
self.saving_best = save_best_to is not None and save_best_to is not False
self.save_meta_to = save_meta_to
self.save_type = save_type
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
"""
Save a general file under save_meta_to
"""
raise NotImplementedError
class LocalSaver(BaseSaver):
def __init__(self,
data_path: str,
**kwargs
):
super().__init__(data_path, **kwargs)
def init(self, logger: BaseLogger, **kwargs) -> None:
# Makes sure the directory exists to be saved to
print(f"Saving {self.save_type} locally")
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
class WandbSaver(BaseSaver):
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.run_path = wandb_run_path
def init(self, logger: BaseLogger, **kwargs) -> None:
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
os.environ["WANDB_SILENT"] = "true"
# Makes sure that the user can upload tot his run
if self.run_path is not None:
entity, project, run_id = self.run_path.split("/")
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
else:
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
self.run = self.wandb.run
# TODO: Now actually check if upload is possible
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# In order to log something in the correct place in wandb, we need to have the same file structure here
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
save_path = Path(self.data_path) / save_path
save_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(local_path, save_path)
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
class HuggingfaceSaver(BaseSaver):
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
super().__init__(data_path, **kwargs)
self.huggingface_repo = huggingface_repo
self.token_path = token_path
def init(self, logger: BaseLogger, **kwargs):
# Makes sure this user can upload to the repo
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
try:
identity = self.hub.whoami() # Errors if not logged in
# Then we are logged in
except:
# We are not logged in. Use the token_path to set the token.
if not os.path.exists(self.token_path):
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
with open(self.token_path, "r") as f:
token = f.read().strip()
self.hub.HfApi.set_access_token(token)
identity = self.hub.whoami()
print(f"Saving to huggingface repo {self.huggingface_repo}")
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
# Saving to huggingface is easy, we just need to upload the file with the correct name
save_path_file_name = Path(save_path).name
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
self.hub.upload_file(
path_or_fileobj=str(local_path),
path_in_repo=str(save_path),
repo_id=self.huggingface_repo
)
saver_type_map = {
'local': LocalSaver,
'wandb': WandbSaver,
'huggingface': HuggingfaceSaver
}
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
if saver_type == 'custom':
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
try:
saver_class = saver_type_map[saver_type]
except KeyError:
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
return saver_class(data_path, **kwargs)
class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if overwrite_data_path:
if self.data_path.exists():
shutil.rmtree(self.data_path)
self.data_path.mkdir(parents=True)
else:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
self.logger: BaseLogger = None
self.loader: Optional[BaseLoader] = None
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def init(self, full_config: BaseModel, extra_config: dict):
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
if self.loader is not None:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
def add_logger(self, logger: BaseLogger):
self.logger = logger
def add_loader(self, loader: BaseLoader):
self.loader = loader
def add_saver(self, saver: BaseSaver):
self.savers.append(saver)
def log(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log(*args, **kwargs)
def log_images(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_images(*args, **kwargs)
def log_file(self, *args, **kwargs):
if self.dummy_mode:
return
self.logger.log_file(*args, **kwargs)
def save_config(self, current_config_path: str, config_name = 'config.json'):
if self.dummy_mode:
return
# Save the config under config_name in the root folder of data_path
shutil.copy(current_config_path, self.data_path / config_name)
for saver in self.savers:
remote_path = Path(saver.save_meta_to) / config_name
saver.save_file(current_config_path, str(remote_path))
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
"""
Gets the state dict to be saved and writes it to file_path.
If save_type is 'checkpoint', we save the entire trainer state dict.
If save_type is 'model', we save only the model state dict.
"""
assert save_type in ['checkpoint', 'model']
if save_type == 'checkpoint':
trainer.save(file_path, overwrite=True, **kwargs)
elif save_type == 'model':
if isinstance(trainer, DiffusionPriorTrainer):
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
state_dict = trainer.unwrap_model(prior).state_dict()
torch.save(state_dict, file_path)
elif isinstance(trainer, DecoderTrainer):
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
if trainer.use_ema:
trainable_unets = decoder.unets
decoder.unets = trainer.unets # Swap EMA unets in
state_dict = decoder.state_dict()
decoder.unets = trainable_unets # Swap back
else:
state_dict = decoder.state_dict()
torch.save(state_dict, file_path)
else:
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
return Path(file_path)
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
if self.dummy_mode:
return
if not is_best and not is_latest:
# Nothing to do
return
# Save the checkpoint and model to data_path
checkpoint_path = self.data_path / 'checkpoint.pth'
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
model_path = self.data_path / 'model.pth'
self._save_state_dict(trainer, 'model', model_path, **kwargs)
print("Saved cached models")
# Call the save methods on the savers
for saver in self.savers:
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
if saver.saving_latest and is_latest:
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
try:
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
if saver.saving_best and is_best:
best_checkpoint_path = saver.save_best_to.format(**kwargs)
try:
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
except Exception as e:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
def recall(self):
if self.loader is not None:
return self.loader.recall()
else:
raise ValueError('No loader specified')

View File

@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
DiffusionPriorNetwork,
XClipAdapter
)
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
# helper functions
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
verbose: bool = False
class Config:
# Each individual log type has it's own arguments that will be passed through the config
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
if self.load_from is None:
return None
return create_loader(self.load_from, data_path, **kwargs)
class TrackerSaveConfig(BaseModel):
save_to: str = 'local'
save_all: bool = False
save_latest: bool = True
save_best: bool = True
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_saver(self.save_to, data_path, **kwargs)
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
# Add the logger
tracker.add_logger(self.log.create(self.data_path))
# Add the loader
if self.load is not None:
tracker.add_loader(self.load.create(self.data_path))
# Add the saver or savers
if isinstance(self.save, list):
for save_config in self.save:
tracker.add_saver(save_config.create(self.data_path))
else:
tracker.add_saver(self.save.create(self.data_path))
# Initialize all the components and verify that all data is valid
tracker.init(full_config, extra_config)
return tracker
# diffusion prior pydantic classes
@@ -238,6 +292,7 @@ class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -247,9 +302,6 @@ class DecoderTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
@@ -271,7 +323,6 @@ class TrainDecoderConfig(BaseModel):
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
seed: int = 0
@classmethod
@@ -294,17 +345,17 @@ class TrainDecoderConfig(BaseModel):
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
if using_text_encodings:
# Then we need some way to get the embeddings
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
if using_text_encodings:
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
else:
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values

View File

@@ -505,12 +505,7 @@ class DecoderTrainer(nn.Module):
self.accelerator.save(save_obj, str(path))
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location = 'cpu')
def load_state_dict(self, loaded_obj, only_model = False, strict = True):
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
@@ -530,6 +525,14 @@ class DecoderTrainer(nn.Module):
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()
loaded_obj = torch.load(str(path), map_location = 'cpu')
self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
return loaded_obj
@property

View File

@@ -1,6 +1,11 @@
import time
import importlib
# helper functions
def exists(val):
return val is not None
# time helpers
class Timer:

View File

@@ -1,11 +1,12 @@
from pathlib import Path
from typing import List
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize
import torchvision
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
for relative_path in relative_paths:
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
def train(
dataloaders,
decoder,
accelerator,
tracker,
decoder: Decoder,
accelerator: Accelerator,
tracker: Tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
@@ -299,13 +291,13 @@ def train(
val_sample = 0
step = lambda: int(trainer.step.item())
if exists(load_config) and exists(load_config.source):
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if tracker.loader is not None:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
@@ -399,19 +391,14 @@ def train(
}
if is_master:
tracker.log(log_data, step=step(), verbose=True)
tracker.log(log_data, step=step())
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
@@ -486,7 +473,7 @@ def train(
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
tracker.log(val_loss_map, step=step())
next_task = 'eval'
if next_task == 'eval':
@@ -494,7 +481,7 @@ def train(
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
tracker.log(evaluation, step=step())
next_task = 'sample'
val_sample = 0
@@ -509,22 +496,16 @@ def train(
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
next_task = 'train'
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
@@ -532,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type
if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
return tracker
def initialize_training(config, config_path):
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),