Overhauled the tracker system (#172)

* Overhauled the tracker system (see the usage sketch after this list)
Separated the logging and saving capabilities
Made tracker creation consistent, with initialization behavior defined by a class initializer instead of in the training script
Added class separation between the different types of loaders and savers to make the system more modular

* Changed the saver system to save the checkpoint only once

* Added better error handling for saving checkpoints

* Fixed an error where wandb would fail when passed arbitrary kwargs

* Fixed variable naming issues in the improved saver
Added more logging during long pauses

* Fixed which methods need to be dummied out so they return immediately
Added the ability to set whether DDP searches for unused parameters (find_unused_parameters)

* Added more logging for when a wandb loader fails
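
A rough sketch of the usage pattern these changes produce in the training script, based only on the calls visible in the diff below (tracker.loader, tracker.recall, tracker.log, tracker.save, trainer.load_state_dict); the function name and the metric values are placeholders, not part of the change:

from typing import List
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.trackers import Tracker

def tracker_usage_sketch(tracker: Tracker, trainer: DecoderTrainer,
                         epoch: int, sample: int, samples_seen: int,
                         validation_losses: List[float]) -> None:
    # Resuming: the loader is configured on the tracker itself, so the script
    # asks the tracker for state instead of passing a recall_source around.
    if tracker.loader is not None:
        state_dict = tracker.recall()
        trainer.load_state_dict(state_dict, only_model=False, strict=True)

    # Logging no longer takes a verbose flag; each logger class decides how to report.
    tracker.log({"Training loss": 0.123}, step=int(trainer.step.item()))

    # Saving goes through the tracker's saver classes; is_latest / is_best replace
    # the old list of relative checkpoint paths, and the checkpoint is written once.
    tracker.save(trainer, is_best=False, is_latest=True,
                 epoch=epoch, sample=sample, next_task="train",
                 validation_losses=validation_losses, samples_seen=samples_seen)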
Aidan Dempster authored on 2022-07-01 12:39:40 -04:00, committed by GitHub
parent 7b0edf9e42
commit 27b0f7ca0d
7 changed files with 662 additions and 212 deletions

@@ -1,11 +1,12 @@
from pathlib import Path
from typing import List
from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize
import torchvision
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
metrics[metric_name] = metrics_tensor[i].item()
return metrics
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
"""
Logs the model with an appropriate method depending on the tracker
"""
if isinstance(relative_paths, str):
relative_paths = [relative_paths]
for relative_path in relative_paths:
local_path = str(tracker.data_path / relative_path)
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
tracker.save_file(local_path)
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
"""
Loads the model with an appropriate method depending on the tracker
"""
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
local_filepath = tracker.recall_file(recall_source, **load_config)
state_dict = trainer.load(local_filepath)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
state_dict = tracker.recall()
trainer.load_state_dict(state_dict, only_model=False, strict=True)
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
def train(
dataloaders,
decoder,
accelerator,
tracker,
decoder: Decoder,
accelerator: Accelerator,
tracker: Tracker,
inference_device,
load_config=None,
evaluate_config=None,
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
validation_samples = None,
epochs = 20,
n_sample_images = 5,
save_every_n_samples = 100000,
save_all=False,
save_latest=True,
save_best=True,
unet_training_mask=None,
condition_on_text_encodings=False,
**kwargs
@@ -299,13 +291,13 @@ def train(
val_sample = 0
step = lambda: int(trainer.step.item())
if exists(load_config) and exists(load_config.source):
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
if tracker.loader is not None:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
if next_task == 'val':
val_sample = recalled_sample
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
trainer.to(device=inference_device)
@@ -399,19 +391,14 @@ def train(
}
if is_master:
tracker.log(log_data, step=step(), verbose=True)
tracker.log(log_data, step=step())
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
print("Saving snapshot")
last_snapshot = sample
# We need to know where the model should be saved
save_paths = []
if save_latest:
save_paths.append("latest.pth")
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
@@ -486,7 +473,7 @@ def train(
if is_master:
unet_average_val_loss = all_average_val_losses.mean(dim=0)
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
tracker.log(val_loss_map, step=step(), verbose=True)
tracker.log(val_loss_map, step=step())
next_task = 'eval'
if next_task == 'eval':
@@ -494,7 +481,7 @@ def train(
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
if is_master:
tracker.log(evaluation, step=step(), verbose=True)
tracker.log(evaluation, step=step())
next_task = 'sample'
val_sample = 0
@@ -509,22 +496,16 @@ def train(
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
# Get the same paths
save_paths = []
if save_latest:
save_paths.append("latest.pth")
is_best = False
if all_average_val_losses is not None:
average_loss = all_average_val_losses.mean(dim=0).item()
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
save_paths.append("best.pth")
if len(validation_losses) == 0 or average_loss < min(validation_losses):
is_best = True
validation_losses.append(average_loss)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
next_task = 'train'
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
tracker_config = config.tracker
accelerator_config = {
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
@@ -532,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
"NumProcesses": accelerator.num_processes,
"MixedPrecision": accelerator.mixed_precision
}
init_config = { "config": {**config.dict(), **accelerator_config} }
data_path = data_path or tracker_config.data_path
tracker_type = tracker_type or tracker_config.tracker_type
if tracker_type == "dummy":
tracker = DummyTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "console":
tracker = ConsoleTracker(data_path)
tracker.init(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
else:
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
tracker.save_config(config_path, config_name='decoder_config.json')
return tracker
def initialize_training(config, config_path):
def initialize_training(config: TrainDecoderConfig, config_path):
# Make sure if we are not loading, distributed models are initialized to the same values
torch.manual_seed(config.seed)
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# Set up data
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
num_parameters = sum(p.numel() for p in decoder.parameters())
# Create and initialize the tracker if we are the master
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
has_img_embeddings = config.data.img_embeddings_url is not None
has_text_embeddings = config.data.text_embeddings_url is not None
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
train(dataloaders, decoder, accelerator,
tracker=tracker,
inference_device=accelerator.device,
load_config=config.load,
evaluate_config=config.evaluate,
condition_on_text_encodings=conditioning_on_text,
**config.train.dict(),
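
For reference, a minimal sketch of the find_unused_parameters wiring introduced in the initialize_training hunk above, using Accelerate's DistributedDataParallelKwargs; the hard-coded boolean stands in for the new config.train.find_unused_parameters field:

from accelerate import Accelerator, DistributedDataParallelKwargs

# Stand-in for config.train.find_unused_parameters, which is now read from the
# training config instead of being hard-coded to True.
find_unused_parameters = True

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])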