mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-22 11:04:21 +01:00
Overhauled the tracker system (#172)
* Overhauled the tracker system Separated the logging and saving capabilities Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script Added class separation between different types of loaders and savers to make the system more verbose * Changed the saver system to only save the checkpoint once * Added better error handling for saving checkpoints * Fixed an error where wandb would error when passed arbitrary kwargs * Fixed variable naming issues for improved saver Added more logging during long pauses * Fixed which methods need to be dummy to immediatly return Added the ability to set whether you find unused parameters * Added more logging for when a wandb loader fails
This commit is contained in:
@@ -15,6 +15,7 @@ from dalle2_pytorch.dalle2_pytorch import (
|
||||
DiffusionPriorNetwork,
|
||||
XClipAdapter
|
||||
)
|
||||
from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -44,13 +45,66 @@ class TrainSplitConfig(BaseModel):
|
||||
raise ValueError(f'{fields.keys()} must sum to 1.0. Found: {actual_sum}')
|
||||
return fields
|
||||
|
||||
class TrackerLogConfig(BaseModel):
|
||||
log_type: str = 'console'
|
||||
verbose: bool = False
|
||||
|
||||
class Config:
|
||||
# Each individual log type has it's own arguments that will be passed through the config
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_logger(self.log_type, data_path, **kwargs)
|
||||
|
||||
class TrackerLoadConfig(BaseModel):
|
||||
load_from: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
if self.load_from is None:
|
||||
return None
|
||||
return create_loader(self.load_from, data_path, **kwargs)
|
||||
|
||||
class TrackerSaveConfig(BaseModel):
|
||||
save_to: str = 'local'
|
||||
save_all: bool = False
|
||||
save_latest: bool = True
|
||||
save_best: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self, data_path: str):
|
||||
kwargs = self.dict()
|
||||
return create_saver(self.save_to, data_path, **kwargs)
|
||||
|
||||
class TrackerConfig(BaseModel):
|
||||
tracker_type: str = 'console' # Decoder currently supports console and wandb
|
||||
data_path: str = './models' # The path where files will be saved locally
|
||||
init_config: Dict[str, Any] = None
|
||||
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
|
||||
wandb_project: str = ''
|
||||
verbose: bool = False # Whether to print console logging for non-console trackers
|
||||
data_path: str = '.tracker_data'
|
||||
overwrite_data_path: bool = False
|
||||
log: TrackerLogConfig
|
||||
load: Optional[TrackerLoadConfig]
|
||||
save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
|
||||
|
||||
def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
|
||||
tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
|
||||
# Add the logger
|
||||
tracker.add_logger(self.log.create(self.data_path))
|
||||
# Add the loader
|
||||
if self.load is not None:
|
||||
tracker.add_loader(self.load.create(self.data_path))
|
||||
# Add the saver or savers
|
||||
if isinstance(self.save, list):
|
||||
for save_config in self.save:
|
||||
tracker.add_saver(save_config.create(self.data_path))
|
||||
else:
|
||||
tracker.add_saver(self.save.create(self.data_path))
|
||||
# Initialize all the components and verify that all data is valid
|
||||
tracker.init(full_config, extra_config)
|
||||
return tracker
|
||||
|
||||
# diffusion prior pydantic classes
|
||||
|
||||
@@ -238,6 +292,7 @@ class DecoderTrainConfig(BaseModel):
|
||||
epochs: int = 20
|
||||
lr: SingularOrIterable(float) = 1e-4
|
||||
wd: SingularOrIterable(float) = 0.01
|
||||
find_unused_parameters: bool = True
|
||||
max_grad_norm: SingularOrIterable(float) = 0.5
|
||||
save_every_n_samples: int = 100000
|
||||
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
|
||||
@@ -247,9 +302,6 @@ class DecoderTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.999
|
||||
amp: bool = False
|
||||
save_all: bool = False # Whether to preserve all checkpoints
|
||||
save_latest: bool = True # Whether to always save the latest checkpoint
|
||||
save_best: bool = True # Whether to save the best checkpoint
|
||||
unet_training_mask: ListOrTuple(bool) = None # If None, use all unets
|
||||
|
||||
class DecoderEvaluateConfig(BaseModel):
|
||||
@@ -271,7 +323,6 @@ class TrainDecoderConfig(BaseModel):
|
||||
train: DecoderTrainConfig
|
||||
evaluate: DecoderEvaluateConfig
|
||||
tracker: TrackerConfig
|
||||
load: DecoderLoadConfig
|
||||
seed: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -294,17 +345,17 @@ class TrainDecoderConfig(BaseModel):
|
||||
img_emb_url = data_config.img_embeddings_url
|
||||
text_emb_url = data_config.text_embeddings_url
|
||||
|
||||
if using_text_embeddings:
|
||||
if using_text_encodings:
|
||||
# Then we need some way to get the embeddings
|
||||
assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
|
||||
|
||||
if using_clip:
|
||||
if using_text_embeddings:
|
||||
if using_text_encodings:
|
||||
assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
|
||||
else:
|
||||
assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
|
||||
|
||||
if text_emb_url:
|
||||
assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
assert using_text_encodings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
|
||||
|
||||
return values
|
||||
|
||||
Reference in New Issue
Block a user