mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Overhauled the tracker system (#172)
* Overhauled the tracker system Separated the logging and saving capabilities Changed creation to be consistent and initializing behavior to be defined by a class initializer instead of in the training script Added class separation between different types of loaders and savers to make the system more verbose * Changed the saver system to only save the checkpoint once * Added better error handling for saving checkpoints * Fixed an error where wandb would error when passed arbitrary kwargs * Fixed variable naming issues for improved saver Added more logging during long pauses * Fixed which methods need to be dummy to immediatly return Added the ability to set whether you find unused parameters * Added more logging for when a wandb loader fails
This commit is contained in:
106
train_decoder.py
106
train_decoder.py
@@ -1,11 +1,12 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from dalle2_pytorch.trainer import DecoderTrainer
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker, DummyTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.trackers import Tracker
|
||||
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
|
||||
from clip import tokenize
|
||||
|
||||
import torchvision
|
||||
@@ -239,42 +240,33 @@ def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=Fa
|
||||
metrics[metric_name] = metrics_tensor[i].item()
|
||||
return metrics
|
||||
|
||||
def save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, relative_paths):
|
||||
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
if isinstance(relative_paths, str):
|
||||
relative_paths = [relative_paths]
|
||||
for relative_path in relative_paths:
|
||||
local_path = str(tracker.data_path / relative_path)
|
||||
trainer.save(local_path, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses)
|
||||
tracker.save_file(local_path)
|
||||
tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
|
||||
|
||||
def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
||||
def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {recall_source}"))
|
||||
local_filepath = tracker.recall_file(recall_source, **load_config)
|
||||
state_dict = trainer.load(local_filepath)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0)
|
||||
trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
|
||||
state_dict = tracker.recall()
|
||||
trainer.load_state_dict(state_dict, only_model=False, strict=True)
|
||||
return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
|
||||
|
||||
def train(
|
||||
dataloaders,
|
||||
decoder,
|
||||
accelerator,
|
||||
tracker,
|
||||
decoder: Decoder,
|
||||
accelerator: Accelerator,
|
||||
tracker: Tracker,
|
||||
inference_device,
|
||||
load_config=None,
|
||||
evaluate_config=None,
|
||||
epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
|
||||
validation_samples = None,
|
||||
epochs = 20,
|
||||
n_sample_images = 5,
|
||||
save_every_n_samples = 100000,
|
||||
save_all=False,
|
||||
save_latest=True,
|
||||
save_best=True,
|
||||
unet_training_mask=None,
|
||||
condition_on_text_encodings=False,
|
||||
**kwargs
|
||||
@@ -299,13 +291,13 @@ def train(
|
||||
val_sample = 0
|
||||
step = lambda: int(trainer.step.item())
|
||||
|
||||
if exists(load_config) and exists(load_config.source):
|
||||
start_epoch, validation_losses, next_task, recalled_sample = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
|
||||
if tracker.loader is not None:
|
||||
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
|
||||
if next_task == 'train':
|
||||
sample = recalled_sample
|
||||
if next_task == 'val':
|
||||
val_sample = recalled_sample
|
||||
accelerator.print(f"Loaded model from {load_config.source} on epoch {start_epoch} with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
|
||||
accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
@@ -399,19 +391,14 @@ def train(
|
||||
}
|
||||
|
||||
if is_master:
|
||||
tracker.log(log_data, step=step(), verbose=True)
|
||||
tracker.log(log_data, step=step())
|
||||
|
||||
if is_master and last_snapshot + save_every_n_samples < sample: # This will miss by some amount every time, but it's not a big deal... I hope
|
||||
# It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
|
||||
print("Saving snapshot")
|
||||
last_snapshot = sample
|
||||
# We need to know where the model should be saved
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step()}.pth")
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
|
||||
@@ -486,7 +473,7 @@ def train(
|
||||
if is_master:
|
||||
unet_average_val_loss = all_average_val_losses.mean(dim=0)
|
||||
val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
|
||||
tracker.log(val_loss_map, step=step(), verbose=True)
|
||||
tracker.log(val_loss_map, step=step())
|
||||
next_task = 'eval'
|
||||
|
||||
if next_task == 'eval':
|
||||
@@ -494,7 +481,7 @@ def train(
|
||||
accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
|
||||
if is_master:
|
||||
tracker.log(evaluation, step=step(), verbose=True)
|
||||
tracker.log(evaluation, step=step())
|
||||
next_task = 'sample'
|
||||
val_sample = 0
|
||||
|
||||
@@ -509,22 +496,16 @@ def train(
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
|
||||
|
||||
print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
|
||||
# Get the same paths
|
||||
save_paths = []
|
||||
if save_latest:
|
||||
save_paths.append("latest.pth")
|
||||
is_best = False
|
||||
if all_average_val_losses is not None:
|
||||
average_loss = all_average_val_losses.mean(dim=0).item()
|
||||
if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
|
||||
save_paths.append("best.pth")
|
||||
if len(validation_losses) == 0 or average_loss < min(validation_losses):
|
||||
is_best = True
|
||||
validation_losses.append(average_loss)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, save_paths)
|
||||
save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
|
||||
next_task = 'train'
|
||||
|
||||
def create_tracker(accelerator, config, config_path, tracker_type=None, data_path=None):
|
||||
"""
|
||||
Creates a tracker of the specified type and initializes special features based on the full config
|
||||
"""
|
||||
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
|
||||
tracker_config = config.tracker
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
|
||||
@@ -532,40 +513,16 @@ def create_tracker(accelerator, config, config_path, tracker_type=None, data_pat
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision
|
||||
}
|
||||
init_config = { "config": {**config.dict(), **accelerator_config} }
|
||||
data_path = data_path or tracker_config.data_path
|
||||
tracker_type = tracker_type or tracker_config.tracker_type
|
||||
|
||||
if tracker_type == "dummy":
|
||||
tracker = DummyTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "console":
|
||||
tracker = ConsoleTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
elif tracker_type == "wandb":
|
||||
# We need to initialize the resume state here
|
||||
load_config = config.load
|
||||
if load_config.source == "wandb" and load_config.resume:
|
||||
# Then we are resuming the run load_config["run_path"]
|
||||
run_id = load_config.run_path.split("/")[-1]
|
||||
init_config["id"] = run_id
|
||||
init_config["resume"] = "must"
|
||||
|
||||
init_config["entity"] = tracker_config.wandb_entity
|
||||
init_config["project"] = tracker_config.wandb_project
|
||||
tracker = WandbTracker(data_path)
|
||||
tracker.init(**init_config)
|
||||
tracker.save_file(str(config_path.absolute()), str(config_path.parent.absolute()))
|
||||
else:
|
||||
raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
|
||||
tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
|
||||
tracker.save_config(config_path, config_name='decoder_config.json')
|
||||
return tracker
|
||||
|
||||
def initialize_training(config, config_path):
|
||||
def initialize_training(config: TrainDecoderConfig, config_path):
|
||||
# Make sure if we are not loading, distributed models are initialized to the same values
|
||||
torch.manual_seed(config.seed)
|
||||
|
||||
# Set up accelerator for configurable distributed training
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
# Set up data
|
||||
@@ -592,7 +549,7 @@ def initialize_training(config, config_path):
|
||||
num_parameters = sum(p.numel() for p in decoder.parameters())
|
||||
|
||||
# Create and initialize the tracker if we are the master
|
||||
tracker = create_tracker(accelerator, config, config_path) if rank == 0 else create_tracker(accelerator, config, config_path, tracker_type="dummy")
|
||||
tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
|
||||
|
||||
has_img_embeddings = config.data.img_embeddings_url is not None
|
||||
has_text_embeddings = config.data.text_embeddings_url is not None
|
||||
@@ -622,7 +579,6 @@ def initialize_training(config, config_path):
|
||||
train(dataloaders, decoder, accelerator,
|
||||
tracker=tracker,
|
||||
inference_device=accelerator.device,
|
||||
load_config=config.load,
|
||||
evaluate_config=config.evaluate,
|
||||
condition_on_text_encodings=conditioning_on_text,
|
||||
**config.train.dict(),
|
||||
|
||||
Reference in New Issue
Block a user