mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Overhauled the tracker system (#172)
* Overhauled the tracker system Separated the logging and saving capabilities Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script Added class separation between different types of loaders and savers to make the system more verbose * Changed the saver system to only save the checkpoint once * Added better error handling for saving checkpoints * Fixed an error where wandb would error when passed arbitrary kwargs * Fixed variable naming issues for improved saver Added more logging during long pauses * Fixed which methods need to be dummy to immediatly return Added the ability to set whether you find unused parameters * Added more logging for when a wandb loader fails
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
import urllib.request
|
||||
import os
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import shutil
|
||||
from itertools import zip_longest
|
||||
from typing import Optional, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dalle2_pytorch.utils import import_or_print_error
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
|
||||
|
||||
# constants
|
||||
|
||||
@@ -27,126 +30,484 @@ def load_wandb_file(run_path, file_path, **kwargs):
|
||||
def load_local_file(file_path, **kwargs):
|
||||
return file_path
|
||||
|
||||
# base class
|
||||
|
||||
class BaseTracker(nn.Module):
|
||||
def __init__(self, data_path = DEFAULT_DATA_PATH):
|
||||
super().__init__()
|
||||
class BaseLogger:
|
||||
"""
|
||||
An abstract class representing an object that can log data.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
verbose (bool): Whether of not to always print logs to the console.
|
||||
"""
|
||||
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
self.data_path.mkdir(parents = True, exist_ok = True)
|
||||
self.verbose = verbose
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_state_dict(self, recall_source, *args, **kwargs):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
"""
|
||||
Loads a state dict from any source.
|
||||
Since a user may wish to load a model from a different source than their own tracker (i.e. tracking using wandb but recalling from disk),
|
||||
this should not be linked to any individual tracker.
|
||||
Initializes the logger.
|
||||
Errors if the logger is invalid.
|
||||
"""
|
||||
# TODO: Pull this into a dict or something similar so that we can add more sources without having a massive switch statement
|
||||
if recall_source == 'wandb':
|
||||
return torch.load(load_wandb_file(*args, **kwargs))
|
||||
elif recall_source == 'local':
|
||||
return torch.load(load_local_file(*args, **kwargs))
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def recall_file(self, recall_source, *args, **kwargs):
|
||||
if recall_source == 'wandb':
|
||||
return load_wandb_file(*args, **kwargs)
|
||||
elif recall_source == 'local':
|
||||
return load_local_file(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('`recall_source` must be one of `wandb` or `local`')
|
||||
def log(self, log, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Tracker that no-ops all calls except for recall
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class DummyTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def init(self, config, **kwargs):
|
||||
pass
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
pass
|
||||
class ConsoleLogger(BaseLogger):
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
print("Logging to console")
|
||||
|
||||
def log_images(self, images, **kwargs):
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
pass
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
pass
|
||||
|
||||
# basic stdout class
|
||||
|
||||
class ConsoleTracker(BaseTracker):
|
||||
def init(self, **config):
|
||||
print(config)
|
||||
|
||||
def log(self, log, **kwargs):
|
||||
def log(self, log, **kwargs) -> None:
|
||||
print(log)
|
||||
|
||||
def log_images(self, images, **kwargs): # noop for logging images
|
||||
pass
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
torch.save(state_dict, str(self.data_path / relative_path))
|
||||
|
||||
def save_file(self, file_path, **kwargs):
|
||||
# This is a no-op for local file systems since it is already saved locally
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
# basic wandb class
|
||||
def log_file(self, file_path, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
class WandbTracker(BaseTracker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb experiment tracker')
|
||||
def log_error(self, error_string, **kwargs) -> None:
|
||||
print(error_string)
|
||||
|
||||
class WandbLogger(BaseLogger):
|
||||
"""
|
||||
Logs to a wandb run.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
wandb_entity (str): The wandb entity to log to.
|
||||
wandb_project (str): The wandb project to log to.
|
||||
wandb_run_id (str): The wandb run id to resume.
|
||||
wandb_run_name (str): The wandb run name to use.
|
||||
wandb_resume (bool): Whether to resume a wandb run.
|
||||
"""
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
wandb_entity: str,
|
||||
wandb_project: str,
|
||||
wandb_run_id: Optional[str] = None,
|
||||
wandb_run_name: Optional[str] = None,
|
||||
wandb_resume: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.entity = wandb_entity
|
||||
self.project = wandb_project
|
||||
self.run_id = wandb_run_id
|
||||
self.run_name = wandb_run_name
|
||||
self.resume = wandb_resume
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
|
||||
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
|
||||
assert self.project is not None, "wandb_project must be specified for wandb logger"
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Initializes the wandb run
|
||||
init_object = {
|
||||
"entity": self.entity,
|
||||
"project": self.project,
|
||||
"config": {**full_config.dict(), **extra_config}
|
||||
}
|
||||
if self.run_name is not None:
|
||||
init_object['name'] = self.run_name
|
||||
if self.resume:
|
||||
assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
|
||||
if self.run_name is not None:
|
||||
print("You are renaming a run. I hope that is what you intended.")
|
||||
init_object['resume'] = 'must'
|
||||
init_object['id'] = self.run_id
|
||||
|
||||
def init(self, **config):
|
||||
self.wandb.init(**config)
|
||||
self.wandb.init(**init_object)
|
||||
print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
|
||||
|
||||
def log(self, log, verbose=False, **kwargs):
|
||||
if verbose:
|
||||
def log(self, log, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(log)
|
||||
self.wandb.log(log, **kwargs)
|
||||
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs):
|
||||
def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
|
||||
"""
|
||||
Takes a tensor of images and a list of captions and logs them to wandb.
|
||||
"""
|
||||
wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
|
||||
self.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_state_dict(self, state_dict, relative_path, **kwargs):
|
||||
"""
|
||||
Saves a state_dict to disk and uploads it
|
||||
"""
|
||||
full_path = str(self.data_path / relative_path)
|
||||
torch.save(state_dict, full_path)
|
||||
self.wandb.save(full_path, base_path = str(self.data_path)) # Upload and keep relative to data_path
|
||||
self.wandb.log({ image_section: wandb_images }, **kwargs)
|
||||
|
||||
def save_file(self, file_path, base_path=None, **kwargs):
|
||||
"""
|
||||
Uploads a file from disk to wandb
|
||||
"""
|
||||
def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
|
||||
if base_path is None:
|
||||
base_path = self.data_path
|
||||
# Then we take the basepath as the parent of the file_path
|
||||
base_path = Path(file_path).parent
|
||||
self.wandb.save(str(file_path), base_path = str(base_path))
|
||||
|
||||
def log_error(self, error_string, step=None, **kwargs) -> None:
|
||||
if self.verbose:
|
||||
print(error_string)
|
||||
self.wandb.log({"error": error_string, **kwargs}, step=step)
|
||||
|
||||
logger_type_map = {
|
||||
'console': ConsoleLogger,
|
||||
'wandb': WandbLogger,
|
||||
}
|
||||
def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
|
||||
if logger_type == 'custom':
|
||||
raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
|
||||
try:
|
||||
logger_class = logger_type_map[logger_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
|
||||
return logger_class(data_path, **kwargs)
|
||||
|
||||
class BaseLoader:
|
||||
"""
|
||||
An abstract class representing an object that can load a model checkpoint.
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
"""
|
||||
def __init__(self, data_path: str, **kwargs):
|
||||
self.data_path = Path(data_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def recall() -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
class UrlLoader(BaseLoader):
|
||||
"""
|
||||
A loader that downloads the file from a url and loads it
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
url (str): The url to download the file from.
|
||||
"""
|
||||
def __init__(self, data_path: str, url: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.url = url
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be downloaded
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Download the file
|
||||
save_path = self.data_path / 'loaded_checkpoint.pth'
|
||||
urllib.request.urlretrieve(self.url, str(save_path))
|
||||
# Load the file
|
||||
return torch.load(str(save_path), map_location='cpu')
|
||||
|
||||
|
||||
class LocalLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a file from a local path
|
||||
Parameters:
|
||||
data_path (str): A file path for storing temporary data.
|
||||
file_path (str): The path to the file to load.
|
||||
"""
|
||||
def __init__(self, data_path: str, file_path: str, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.file_path = Path(file_path)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the file exists to be loaded
|
||||
if not self.file_path.exists():
|
||||
raise FileNotFoundError(f'Model not found at {self.file_path}')
|
||||
|
||||
def recall(self) -> dict:
|
||||
# Load the file
|
||||
return torch.load(str(self.file_path), map_location='cpu')
|
||||
|
||||
class WandbLoader(BaseLoader):
|
||||
"""
|
||||
A loader that loads a model from an existing wandb run
|
||||
"""
|
||||
def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
self.file_path = wandb_file_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
|
||||
# Make sure the file can be downloaded
|
||||
if self.wandb.run is not None and self.run_path is None:
|
||||
self.run_path = self.wandb.run.path
|
||||
assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'
|
||||
assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'
|
||||
assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
|
||||
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
pass # TODO: Actually implement that
|
||||
|
||||
def recall(self) -> dict:
|
||||
file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
|
||||
return torch.load(file_reference.name, map_location='cpu')
|
||||
|
||||
loader_type_map = {
|
||||
'url': UrlLoader,
|
||||
'local': LocalLoader,
|
||||
'wandb': WandbLoader,
|
||||
}
|
||||
def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
|
||||
if loader_type == 'custom':
|
||||
raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
|
||||
try:
|
||||
loader_class = loader_type_map[loader_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
|
||||
return loader_class(data_path, **kwargs)
|
||||
|
||||
class BaseSaver:
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
save_latest_to: Optional[Union[str, bool]] = 'latest.pth',
|
||||
save_best_to: Optional[Union[str, bool]] = 'best.pth',
|
||||
save_meta_to: str = './',
|
||||
save_type: str = 'checkpoint',
|
||||
**kwargs
|
||||
):
|
||||
self.data_path = Path(data_path)
|
||||
self.save_latest_to = save_latest_to
|
||||
self.saving_latest = save_latest_to is not None and save_latest_to is not False
|
||||
self.save_best_to = save_best_to
|
||||
self.saving_best = save_best_to is not None and save_best_to is not False
|
||||
self.save_meta_to = save_meta_to
|
||||
self.save_type = save_type
|
||||
assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
|
||||
assert self.save_meta_to is not None, '`save_meta_to` must be provided'
|
||||
assert self.saving_latest or self.saving_best, '`save_latest_to` or `save_best_to` must be provided'
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
|
||||
"""
|
||||
Save a general file under save_meta_to
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
class LocalSaver(BaseSaver):
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(data_path, **kwargs)
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
# Makes sure the directory exists to be saved to
|
||||
print(f"Saving {self.save_type} locally")
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
|
||||
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
|
||||
# Copy the file to save_path
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
|
||||
shutil.copy(local_path, save_path)
|
||||
|
||||
class WandbSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.run_path = wandb_run_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs) -> None:
|
||||
self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
# Makes sure that the user can upload tot his run
|
||||
if self.run_path is not None:
|
||||
entity, project, run_id = self.run_path.split("/")
|
||||
self.run = self.wandb.init(entity=entity, project=project, id=run_id)
|
||||
else:
|
||||
assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
|
||||
self.run = self.wandb.run
|
||||
# TODO: Now actually check if upload is possible
|
||||
print(f"Saving to wandb run {self.run.path}-{self.run.name}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# In order to log something in the correct place in wandb, we need to have the same file structure here
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
|
||||
save_path = Path(self.data_path) / save_path
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(local_path, save_path)
|
||||
self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
|
||||
|
||||
class HuggingfaceSaver(BaseSaver):
|
||||
def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
|
||||
super().__init__(data_path, **kwargs)
|
||||
self.huggingface_repo = huggingface_repo
|
||||
self.token_path = token_path
|
||||
|
||||
def init(self, logger: BaseLogger, **kwargs):
|
||||
# Makes sure this user can upload to the repo
|
||||
self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
|
||||
try:
|
||||
identity = self.hub.whoami() # Errors if not logged in
|
||||
# Then we are logged in
|
||||
except:
|
||||
# We are not logged in. Use the token_path to set the token.
|
||||
if not os.path.exists(self.token_path):
|
||||
raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
|
||||
with open(self.token_path, "r") as f:
|
||||
token = f.read().strip()
|
||||
self.hub.HfApi.set_access_token(token)
|
||||
identity = self.hub.whoami()
|
||||
print(f"Saving to huggingface repo {self.huggingface_repo}")
|
||||
|
||||
def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
|
||||
# Saving to huggingface is easy, we just need to upload the file with the correct name
|
||||
save_path_file_name = Path(save_path).name
|
||||
print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
|
||||
self.hub.upload_file(
|
||||
path_or_fileobj=str(local_path),
|
||||
path_in_repo=str(save_path),
|
||||
repo_id=self.huggingface_repo
|
||||
)
|
||||
|
||||
saver_type_map = {
|
||||
'local': LocalSaver,
|
||||
'wandb': WandbSaver,
|
||||
'huggingface': HuggingfaceSaver
|
||||
}
|
||||
def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
|
||||
if saver_type == 'custom':
|
||||
raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
|
||||
try:
|
||||
saver_class = saver_type_map[saver_type]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
|
||||
return saver_class(data_path, **kwargs)
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
|
||||
self.data_path = Path(data_path)
|
||||
if not dummy_mode:
|
||||
if overwrite_data_path:
|
||||
if self.data_path.exists():
|
||||
shutil.rmtree(self.data_path)
|
||||
self.data_path.mkdir(parents=True)
|
||||
else:
|
||||
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
|
||||
if not self.data_path.exists():
|
||||
self.data_path.mkdir(parents=True)
|
||||
self.logger: BaseLogger = None
|
||||
self.loader: Optional[BaseLoader] = None
|
||||
self.savers: List[BaseSaver]= []
|
||||
self.dummy_mode = dummy_mode
|
||||
|
||||
def init(self, full_config: BaseModel, extra_config: dict):
|
||||
assert self.logger is not None, '`logger` must be set before `init` is called'
|
||||
if self.dummy_mode:
|
||||
# The only thing we need is a loader
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
return
|
||||
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
|
||||
self.logger.init(full_config, extra_config)
|
||||
if self.loader is not None:
|
||||
self.loader.init(self.logger)
|
||||
for saver in self.savers:
|
||||
saver.init(self.logger)
|
||||
|
||||
def add_logger(self, logger: BaseLogger):
|
||||
self.logger = logger
|
||||
|
||||
def add_loader(self, loader: BaseLoader):
|
||||
self.loader = loader
|
||||
|
||||
def add_saver(self, saver: BaseSaver):
|
||||
self.savers.append(saver)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log(*args, **kwargs)
|
||||
|
||||
def log_images(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_images(*args, **kwargs)
|
||||
|
||||
def log_file(self, *args, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
self.logger.log_file(*args, **kwargs)
|
||||
|
||||
def save_config(self, current_config_path: str, config_name = 'config.json'):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
# Save the config under config_name in the root folder of data_path
|
||||
shutil.copy(current_config_path, self.data_path / config_name)
|
||||
for saver in self.savers:
|
||||
remote_path = Path(saver.save_meta_to) / config_name
|
||||
saver.save_file(current_config_path, str(remote_path))
|
||||
|
||||
def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
|
||||
"""
|
||||
Gets the state dict to be saved and writes it to file_path.
|
||||
If save_type is 'checkpoint', we save the entire trainer state dict.
|
||||
If save_type is 'model', we save only the model state dict.
|
||||
"""
|
||||
assert save_type in ['checkpoint', 'model']
|
||||
if save_type == 'checkpoint':
|
||||
trainer.save(file_path, overwrite=True, **kwargs)
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
state_dict = trainer.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
if trainer.use_ema:
|
||||
trainable_unets = decoder.unets
|
||||
decoder.unets = trainer.unets # Swap EMA unets in
|
||||
state_dict = decoder.state_dict()
|
||||
decoder.unets = trainable_unets # Swap back
|
||||
else:
|
||||
state_dict = decoder.state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
else:
|
||||
raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
|
||||
return Path(file_path)
|
||||
|
||||
def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
|
||||
if self.dummy_mode:
|
||||
return
|
||||
if not is_best and not is_latest:
|
||||
# Nothing to do
|
||||
return
|
||||
# Save the checkpoint and model to data_path
|
||||
checkpoint_path = self.data_path / 'checkpoint.pth'
|
||||
self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
|
||||
model_path = self.data_path / 'model.pth'
|
||||
self._save_state_dict(trainer, 'model', model_path, **kwargs)
|
||||
print("Saved cached models")
|
||||
# Call the save methods on the savers
|
||||
for saver in self.savers:
|
||||
local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
|
||||
if saver.saving_latest and is_latest:
|
||||
latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
if saver.saving_best and is_best:
|
||||
best_checkpoint_path = saver.save_best_to.format(**kwargs)
|
||||
try:
|
||||
saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
|
||||
print(f'Error saving checkpoint: {e}')
|
||||
|
||||
def recall(self):
|
||||
if self.loader is not None:
|
||||
return self.loader.recall()
|
||||
else:
|
||||
raise ValueError('No loader specified')
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user