Overhauled the tracker system (#172)

* Overhauled the tracker system
Separated the logging and saving capabilities
Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script
Added class separation between different types of loaders and savers to make the system more verbose

* Changed the saver system to only save the checkpoint once

* Added better error handling for saving checkpoints

* Fixed an error where wandb would error when passed arbitrary kwargs

* Fixed variable naming issues for improved saver
Added more logging during long pauses

* Fixed which methods need to be dummy to immediatly return
Added the ability to set whether you find unused parameters

* Added more logging for when a wandb loader fails
This commit is contained in:
Aidan Dempster
2022-07-01 12:39:40 -04:00
committed by GitHub
parent 7b0edf9e42
commit 27b0f7ca0d
7 changed files with 662 additions and 212 deletions

View File

@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
DiffusionPriorNetwork,
XClipAdapter
)
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
# helper functions
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
return fields
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
verbose: bool = False
class Config:
# Each individual log type has it's own arguments that will be passed through the config
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_logger(self.log_type, data_path, **kwargs)
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
if self.load_from is None:
return None
return create_loader(self.load_from, data_path, **kwargs)
class TrackerSaveConfig(BaseModel):
save_to: str = 'local'
save_all: bool = False
save_latest: bool = True
save_best: bool = True
class Config:
extra = "allow"
def create(self, data_path: str):
kwargs = self.dict()
return create_saver(self.save_to, data_path, **kwargs)
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
data_path: str = '.tracker_data'
overwrite_data_path: bool = False
log: TrackerLogConfig
load: Optional[TrackerLoadConfig]
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
# Add the logger
tracker.add_logger(self.log.create(self.data_path))
# Add the loader
if self.load is not None:
tracker.add_loader(self.load.create(self.data_path))
# Add the saver or savers
if isinstance(self.save, list):
for save_config in self.save:
tracker.add_saver(save_config.create(self.data_path))
else:
tracker.add_saver(self.save.create(self.data_path))
# Initialize all the components and verify that all data is valid
tracker.init(full_config, extra_config)
return tracker
# diffusion prior pydantic classes
@@ -238,6 +292,7 @@ class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: SingularOrIterable(float) = 1e-4
wd: SingularOrIterable(float) = 0.01
find_unused_parameters: bool = True
max_grad_norm: SingularOrIterable(float) = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -247,9 +302,6 @@ class DecoderTrainConfig(BaseModel):
use_ema: bool = True
ema_beta: float = 0.999
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
@@ -271,7 +323,6 @@ class TrainDecoderConfig(BaseModel):
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
seed: int = 0
@classmethod
@@ -294,17 +345,17 @@ class TrainDecoderConfig(BaseModel):
img_emb_url = data_config.img_embeddings_url
text_emb_url = data_config.text_embeddings_url
if using_text_embeddings:
if using_text_encodings:
# Then we need some way to get the embeddings
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
if using_clip:
if using_text_embeddings:
if using_text_encodings:
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
else:
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
if text_emb_url:
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
return values