From f9423d308b6f36e51152c2c45045ff4ebb308287 Mon Sep 17 00:00:00 2001 From: zion <51308183+nousr@users.noreply.github.com> Date: Wed, 20 Jul 2022 18:04:26 -0700 Subject: [PATCH] Prior updates (#211) * update configs for prior add prior warmup to config update example prior config * update prior trainer & script add deepspeed amp & warmup adopt full accelerator support reload at sample point finish epoch resume code * update tracker save method for prior * helper functions for prior_loader --- configs/train_prior_config.example.json | 65 +- dalle2_pytorch/dataloaders/prior_loader.py | 9 + dalle2_pytorch/trackers.py | 8 +- dalle2_pytorch/train_configs.py | 32 +- dalle2_pytorch/trainer.py | 163 ++--- train_diffusion_prior.py | 751 +++++++++++++++------ 6 files changed, 676 insertions(+), 352 deletions(-) diff --git a/configs/train_prior_config.example.json b/configs/train_prior_config.example.json index 151ca28..405b853 100644 --- a/configs/train_prior_config.example.json +++ b/configs/train_prior_config.example.json @@ -1,18 +1,14 @@ { "prior": { "clip": { - "make": "x-clip", - "model": "ViT-L/14", - "base_model_kwargs": { - "dim_text": 768, - "dim_image": 768, - "dim_latent": 768 - } + "make": "openai", + "model": "ViT-L/14" }, "net": { "dim": 768, "depth": 12, "num_timesteps": 1000, + "max_text_len": 77, "num_time_embeds": 1, "num_image_embeds": 1, "num_text_embeds": 1, @@ -20,8 +16,8 @@ "heads": 12, "ff_mult": 4, "norm_out": true, - "attn_dropout": 0.0, - "ff_dropout": 0.0, + "attn_dropout": 0.05, + "ff_dropout": 0.05, "final_proj": true, "normformer": true, "rotary_emb": true @@ -30,6 +26,7 @@ "image_size": 224, "image_channels": 3, "timesteps": 1000, + "sample_timesteps": 64, "cond_drop_prob": 0.1, "loss_type": "l2", "predict_x_start": true, @@ -37,34 +34,48 @@ "condition_on_text_encodings": true }, "data": { - "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/", - "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/", - "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/", - "batch_size": 256, + "batch_size": 128, + "num_data_points": 100000, + "eval_every_seconds": 1600, + "image_url": "", + "meta_url": "", "splits": { - "train": 0.9, - "val": 1e-7, - "test": 0.0999999 + "train": 0.8, + "val": 0.1, + "test": 0.1 } }, "train": { - "epochs": 1, + "epochs": 5, "lr": 1.1e-4, "wd": 6.02e-2, "max_grad_norm": 0.5, "use_ema": true, + "ema_beta": 0.9999, + "ema_update_after_step": 50, + "warmup_steps": 50, "amp": false, - "save_every": 10000 - }, - "load": { - "source": null, - "resume": false + "save_every_seconds": 3600, + "eval_timesteps": [64, 1000], + "random_seed": 84513 }, "tracker": { - "tracker_type": "wandb", - "data_path": "./prior_checkpoints", - "wandb_entity": "laion", - "wandb_project": "diffusion-prior", - "verbose": true + "data_path": ".prior", + "overwrite_data_path": true, + "log": { + "log_type": "wandb", + "wandb_entity": "", + "wandb_project": "prior_debugging", + "wandb_resume": false, + "verbose": true + }, + "save": [ + { + "save_to": "local", + "save_type": "checkpoint", + "save_latest_to": ".prior/latest_checkpoint.pth", + "save_best_to": ".prior/best_checkpoint.pth" + } + ] } } diff --git a/dalle2_pytorch/dataloaders/prior_loader.py b/dalle2_pytorch/dataloaders/prior_loader.py index cbbfc57..f612653 100644 --- a/dalle2_pytorch/dataloaders/prior_loader.py +++ b/dalle2_pytorch/dataloaders/prior_loader.py @@ -67,6 +67,15 @@ class 
PriorEmbeddingDataset(IterableDataset): def __str__(self): return f"" + def set_start(self, start): + """ + Adjust the starting point within the reader, useful for resuming an epoch + """ + self.start = start + + def get_start(self): + return self.start + def get_sample(self): """ pre-proocess data from either reader into a common format diff --git a/dalle2_pytorch/trackers.py b/dalle2_pytorch/trackers.py index 13834c9..6f86ede 100644 --- a/dalle2_pytorch/trackers.py +++ b/dalle2_pytorch/trackers.py @@ -528,12 +528,8 @@ class Tracker: elif save_type == 'model': if isinstance(trainer, DiffusionPriorTrainer): prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior - prior: DiffusionPrior = trainer.unwrap_model(prior) - # Remove CLIP if it is part of the model - original_clip = prior.clip - prior.clip = None - model_state_dict = prior.state_dict() - prior.clip = original_clip + state_dict = trainer.accelerator.unwrap_model(prior).state_dict() + torch.save(state_dict, file_path) elif isinstance(trainer, DecoderTrainer): decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder) # Remove CLIP if it is part of the model diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index 4a0c003..307f011 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -145,6 +145,9 @@ class DiffusionPriorNetworkConfig(BaseModel): normformer: bool = False rotary_emb: bool = True + class Config: + extra = "allow" + def create(self): kwargs = self.dict() return DiffusionPriorNetwork(**kwargs) @@ -187,23 +190,26 @@ class DiffusionPriorTrainConfig(BaseModel): use_ema: bool = True ema_beta: float = 0.99 amp: bool = False - save_every: int = 10000 # what steps to save on + warmup_steps: int = None # number of warmup steps + save_every_seconds: int = 3600 # how often to save + eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with + best_validation_loss: float = 1e9 # the current best valudation loss observed + current_epoch: int = 0 # the current epoch + num_samples_seen: int = 0 # the current number of samples seen + random_seed: int = 0 # manual seed for torch class DiffusionPriorDataConfig(BaseModel): - image_url: str # path to embeddings folder - meta_url: str # path to metadata (captions) for images - splits: TrainSplitConfig - batch_size: int = 64 - -class DiffusionPriorLoadConfig(BaseModel): - source: str = None - resume: bool = False + image_url: str # path to embeddings folder + meta_url: str # path to metadata (captions) for images + splits: TrainSplitConfig # define train, validation, test splits for your dataset + batch_size: int # per-gpu batch size used to train the model + num_data_points: int = 25e7 # total number of datapoints to train on + eval_every_seconds: int = 3600 # validation statistics will be performed this often class TrainDiffusionPriorConfig(BaseModel): prior: DiffusionPriorConfig data: DiffusionPriorDataConfig train: DiffusionPriorTrainConfig - load: DiffusionPriorLoadConfig tracker: TrackerConfig @classmethod @@ -323,12 +329,6 @@ class DecoderEvaluateConfig(BaseModel): KID: Dict[str, Any] = None LPIPS: Dict[str, Any] = None -class DecoderLoadConfig(BaseModel): - source: str = None # Supports file and wandb - run_path: str = '' # Used only if source is wandb - file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb. 
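The set_start / get_start helpers added to PriorEmbeddingDataset above are what let a resumed run skip samples it already consumed within an epoch. A minimal, self-contained sketch of that pattern (the ResumableEmbeddingDataset below is a toy stand-in backed by a plain list, not the EmbeddingReader-backed dataset in the repo):

from torch.utils.data import IterableDataset

class ResumableEmbeddingDataset(IterableDataset):
    # toy stand-in: yields items from `start` onward, mirroring PriorEmbeddingDataset's helpers
    def __init__(self, items, start=0):
        self.items = items
        self.start = start

    def set_start(self, start):
        # adjust the starting point within the reader, useful for resuming an epoch
        self.start = start

    def get_start(self):
        return self.start

    def __iter__(self):
        for idx in range(self.start, len(self.items)):
            yield self.items[idx]

# resume part-way through an epoch: skip the first 3 samples a previous run already saw
dataset = ResumableEmbeddingDataset(items=list(range(10)))
dataset.set_start(3)
assert list(dataset) == list(range(3, 10))
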
- resume: bool = False # If using wandb, whether to resume the run - class TrainDecoderConfig(BaseModel): decoder: DecoderConfig data: DecoderDataConfig diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index e0b3b26..41ce286 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -174,27 +174,21 @@ class DiffusionPriorTrainer(nn.Module): def __init__( self, diffusion_prior, + accelerator, use_ema = True, lr = 3e-4, wd = 1e-2, eps = 1e-6, max_grad_norm = None, - amp = False, group_wd_params = True, - device = None, - accelerator = None, - verbose = True, + warmup_steps = 1, **kwargs ): super().__init__() assert isinstance(diffusion_prior, DiffusionPrior) - assert not exists(accelerator) or isinstance(accelerator, Accelerator) + assert isinstance(accelerator, Accelerator) ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) - # verbosity - - self.verbose = verbose - # assign some helpful member vars self.accelerator = accelerator @@ -202,23 +196,31 @@ class DiffusionPriorTrainer(nn.Module): # setting the device - if not exists(accelerator) and not exists(device): - diffusion_prior_device = next(diffusion_prior.parameters()).device - self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}') - self.device = diffusion_prior_device - else: - self.device = accelerator.device if exists(accelerator) else device - diffusion_prior.to(self.device) + self.device = accelerator.device + diffusion_prior.to(self.device) # save model self.diffusion_prior = diffusion_prior - # optimizer and mixed precision stuff + # mixed precision checks - self.amp = amp + if ( + exists(self.accelerator) + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + and self.diffusion_prior.clip is not None + ): + # Then we need to make sure clip is using the correct precision or else deepspeed will error + cast_type_map = { + "fp16": torch.half, + "bf16": torch.bfloat16, + "no": torch.float + } + precision_type = cast_type_map[accelerator.mixed_precision] + assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip" + self.diffusion_prior.clip.to(precision_type) - self.scaler = GradScaler(enabled = amp) + # optimizer stuff self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params) @@ -227,17 +229,21 @@ class DiffusionPriorTrainer(nn.Module): **self.optim_kwargs, **kwargs ) + + self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0) + + self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None # distribute the model if using HFA - if exists(self.accelerator): - self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer) + + self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler) # exponential moving average stuff self.use_ema = use_ema if self.use_ema: - self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs) + self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs) # gradient clipping if needed @@ -247,67 +253,24 @@ class DiffusionPriorTrainer(nn.Module): self.register_buffer('step', torch.tensor([0], device = self.device)) - # accelerator wrappers - - def print(self, msg): - if not self.verbose: - return - - if 
exists(self.accelerator): - self.accelerator.print(msg) - else: - print(msg) - - def unwrap_model(self, model): - if exists(self.accelerator): - return self.accelerator.unwrap_model(model) - else: - return model - - def wait_for_everyone(self): - if exists(self.accelerator): - self.accelerator.wait_for_everyone() - - def is_main_process(self): - if exists(self.accelerator): - return self.accelerator.is_main_process - else: - return True - - def clip_grad_norm_(self, *args): - if exists(self.accelerator): - return self.accelerator.clip_grad_norm_(*args) - else: - return torch.nn.utils.clip_grad_norm_(*args) - - def backprop(self, x): - if exists(self.accelerator): - self.accelerator.backward(x) - else: - try: - x.backward() - except Exception as e: - self.print(f"Caught error in backprop call: {e}") - # utility def save(self, path, overwrite = True, **kwargs): - # ensure we sync gradients before continuing - self.wait_for_everyone() # only save on the main process - if self.is_main_process(): - self.print(f"Saving checkpoint at step: {self.step.item()}") + if self.accelerator.is_main_process: + print(f"Saving checkpoint at step: {self.step.item()}") path = Path(path) assert not (path.exists() and not overwrite) path.parent.mkdir(parents = True, exist_ok = True) + # FIXME: LambdaLR can't be saved due to pickling issues save_obj = dict( - scaler = self.scaler.state_dict(), optimizer = self.optimizer.state_dict(), - model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable + warmup_scheduler = self.warmup_scheduler, + model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(), version = version.parse(__version__), - step = self.step.item(), + step = self.step, **kwargs ) @@ -320,14 +283,14 @@ class DiffusionPriorTrainer(nn.Module): torch.save(save_obj, str(path)) - def load(self, path, overwrite_lr = True, strict = True): + def load(self, path_or_state, overwrite_lr = True, strict = True): """ Load a checkpoint of a diffusion prior trainer. Will load the entire trainer, including the optimizer and EMA. Params: - - path (str): a path to the DiffusionPriorTrainer checkpoint file + - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match @@ -336,56 +299,56 @@ class DiffusionPriorTrainer(nn.Module): """ # all processes need to load checkpoint. 
no restriction here - path = Path(path) - assert path.exists() + if isinstance(path_or_state, str): + path = Path(path) + assert path.exists() + loaded_obj = torch.load(str(path), map_location=self.device) - loaded_obj = torch.load(str(path), map_location=self.device) + elif isinstance(path_or_state, dict): + loaded_obj = path_or_state if version.parse(__version__) != loaded_obj['version']: print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}') # unwrap the model when loading from checkpoint - self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict) - self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) - - self.scaler.load_state_dict(loaded_obj['scaler']) + self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict) + self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device)) self.optimizer.load_state_dict(loaded_obj['optimizer']) + # set warmupstep + if exists(self.warmup_scheduler): + self.warmup_scheduler.last_step = self.step.item() + + # ensure new lr is used if different from old one if overwrite_lr: new_lr = self.optim_kwargs["lr"] - self.print(f"Overriding LR to be {new_lr}") - for group in self.optimizer.param_groups: - group["lr"] = new_lr + group["lr"] = new_lr if group["lr"] > 0.0 else 0.0 if self.use_ema: assert 'ema' in loaded_obj self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict) - # below not be necessary, but I had a suspicion that this wasn't being loaded correctly + # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"]) - # sync and inform - self.wait_for_everyone() - self.print(f"Loaded model") - return loaded_obj # model functionality def update(self): - # only continue with updates until all ranks finish - self.wait_for_everyone() if exists(self.max_grad_norm): - self.scaler.unscale_(self.optimizer) - # utilize HFA clipping where applicable - self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm) - - self.scaler.step(self.optimizer) - self.scaler.update() + self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm) + + self.optimizer.step() self.optimizer.zero_grad() + # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy" + if not self.accelerator.optimizer_step_was_skipped: + with self.warmup_scheduler.dampening(): + self.scheduler.step() + if self.use_ema: self.ema_diffusion_prior.update() @@ -414,7 +377,7 @@ class DiffusionPriorTrainer(nn.Module): @cast_torch_tensor @prior_sample_in_chunks def embed_text(self, *args, **kwargs): - return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs) + return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs) @cast_torch_tensor def forward( @@ -426,16 +389,14 @@ class DiffusionPriorTrainer(nn.Module): total_loss = 0. 
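The update() path above pairs a constant LambdaLR with pytorch_warmup's LinearWarmup and only advances the schedule inside dampening() when the optimizer step was not skipped. A standalone sketch of that scheduler arrangement, outside the diff (the toy model, optimizer settings, and loop are illustrative, not the trainer's; the accelerator skip check is omitted):

import torch
import pytorch_warmup as warmup
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(768, 768)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.1e-4)

# constant "base" schedule; the warmup wrapper scales the LR up from zero
scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=50)

for step in range(200):
    loss = model(torch.randn(8, 768)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # dampening() applies the current warmup factor, then lets the wrapped scheduler
    # advance; after warmup_period steps the factor reaches 1.0 and the LR holds at 1.1e-4
    with warmup_scheduler.dampening():
        scheduler.step()
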
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): - with autocast(enabled = self.amp): + with self.accelerator.autocast(): loss = self.diffusion_prior(*chunked_args, **chunked_kwargs) loss = loss * chunk_size_frac total_loss += loss.item() - # backprop with accelerate if applicable - if self.training: - self.backprop(self.scaler.scale(loss)) + self.accelerator.backward(loss) return total_loss diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 454dc79..0887956 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -1,31 +1,23 @@ -# TODO: add start, num_data_points, eval_every and group to config -# TODO: switch back to repo's wandb - -START = 0 -NUM_DATA_POINTS = 250e6 -EVAL_EVERY = 1000 -GROUP = "distributed" - -import os import click -import wandb - import torch + from torch import nn -from torch.utils.data import DataLoader - -import numpy as np - +from typing import List from accelerate import Accelerator +from accelerate.utils import set_seed +from torch.utils.data import DataLoader +from embedding_reader import EmbeddingReader +from accelerate.utils import dataclasses as accelerate_dataclasses -from dalle2_pytorch.dataloaders import get_reader, make_splits from dalle2_pytorch.utils import Timer +from dalle2_pytorch.trackers import Tracker +from dalle2_pytorch import DiffusionPriorTrainer +from dalle2_pytorch.dataloaders import get_reader, make_splits from dalle2_pytorch.train_configs import ( + DiffusionPriorConfig, DiffusionPriorTrainConfig, TrainDiffusionPriorConfig, ) -from dalle2_pytorch.trackers import BaseTracker, WandbTracker -from dalle2_pytorch import DiffusionPriorTrainer # helpers @@ -38,8 +30,19 @@ def exists(val): return val is not None +def all_between(values: list, lower_bound, upper_bound): + for value in values: + if value < lower_bound or value > upper_bound: + return False + + return True + + def make_model( - prior_config, train_config, device: str = None, accelerator: Accelerator = None + prior_config: DiffusionPriorConfig, + train_config: DiffusionPriorTrainConfig, + device: str = None, + accelerator: Accelerator = None, ): # create model from config diffusion_prior = prior_config.create() @@ -54,71 +57,214 @@ def make_model( use_ema=train_config.use_ema, device=device, accelerator=accelerator, + warmup_steps=train_config.warmup_steps, ) return trainer +def create_tracker( + accelerator: Accelerator, + config: TrainDiffusionPriorConfig, + config_path: str, + dummy: bool = False, +) -> Tracker: + tracker_config = config.tracker + + accelerator_config = { + "Distributed": accelerator.distributed_type + != accelerate_dataclasses.DistributedType.NO, + "DistributedType": accelerator.distributed_type, + "NumProcesses": accelerator.num_processes, + "MixedPrecision": accelerator.mixed_precision, + } + + tracker: Tracker = tracker_config.create( + config, accelerator_config, dummy_mode=dummy + ) + + tracker.save_config(config_path, config_name="prior_config.json") + + return tracker + + +def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"): + """ + pad a value or tensor across all processes and gather + + params: + - trainer: a trainer that carries an accelerator object + - x: a number or torch tensor to reduce + - method: "mean", "sum", "max", "min" + + return: + - the average tensor after maskin out 0's + - None if the gather resulted in an empty tensor + """ + + assert method in [ + "mean", + "sum", + "max", + "min", + ], "This function has 
limited capabilities [sum, mean, max, min]" + assert type(x) is not None, "Cannot reduce a None type object" + + # wait for everyone to arrive here before gathering + + if type(x) is not torch.Tensor: + x = torch.tensor([x]) + + # verify that the tensor is on the proper device + x = x.to(trainer.device) + + # pad across processes + padded_x = trainer.accelerator.pad_across_processes(x, dim=0) + + # gather across all procesess + gathered_x = trainer.accelerator.gather(padded_x) + + # mask out zeros + masked_x = gathered_x[gathered_x != 0] + + # if the tensor is empty, warn and return None + if len(masked_x) == 0: + click.secho( + f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.", + fg="red", + ) + return None + + if method == "mean": + return torch.mean(masked_x) + elif method == "sum": + return torch.sum(masked_x) + elif method == "max": + return torch.max(masked_x) + elif method == "min": + return torch.min(masked_x) + + +def save_trainer( + tracker: Tracker, + trainer: DiffusionPriorTrainer, + is_latest: bool, + is_best: bool, + epoch: int, + samples_seen: int, + best_validation_loss: float, +): + """ + Logs the model with an appropriate method depending on the tracker + """ + trainer.accelerator.wait_for_everyone() + + if trainer.accelerator.is_main_process: + click.secho( + f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}", + fg="magenta", + ) + + tracker.save( + trainer=trainer, + is_best=is_best, + is_latest=is_latest, + epoch=int(epoch), + samples_seen=int(samples_seen), + best_validation_loss=best_validation_loss, + ) + + +def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer): + """ + Loads the model with an appropriate method depending on the tracker + """ + + if trainer.accelerator.is_main_process: + click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow") + + state_dict = tracker.recall() + + trainer.load(state_dict, strict=True) + + return ( + int(state_dict.get("epoch", 0)), + state_dict.get("best_validation_loss", 0), + int(state_dict.get("samples_seen", 0)), + ) + + # eval functions -def eval_model( +def report_validation_loss( trainer: DiffusionPriorTrainer, dataloader: DataLoader, text_conditioned: bool, + use_ema: bool, + tracker: Tracker, + split: str, + tracker_folder: str, loss_type: str, - tracker_context: str, - tracker: BaseTracker = None, - use_ema: bool = True, ): - trainer.eval() - if trainer.is_main_process(): - click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True) + """ + Compute the validation loss on a given subset of data. 
+ """ - with torch.no_grad(): - total_loss = 0.0 - total_samples = 0.0 + if trainer.accelerator.is_main_process: + click.secho( + f"Measuring performance on {use_ema}-{split} split", + fg="green", + blink=True, + ) - for image_embeddings, text_data in dataloader: - image_embeddings = image_embeddings.to(trainer.device) - text_data = text_data.to(trainer.device) + total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device) - batches = image_embeddings.shape[0] + for image_embeddings, text_data in dataloader: + image_embeddings = image_embeddings.to(trainer.device) + text_data = text_data.to(trainer.device) - input_args = dict(image_embed=image_embeddings) + input_args = dict(image_embed=image_embeddings) - if text_conditioned: - input_args = dict(**input_args, text=text_data) - else: - input_args = dict(**input_args, text_embed=text_data) + if text_conditioned: + input_args = dict(**input_args, text=text_data) + else: + input_args = dict(**input_args, text_embed=text_data) - if use_ema: - loss = trainer.ema_diffusion_prior(**input_args) - else: - loss = trainer(**input_args) + if use_ema: + loss = trainer.ema_diffusion_prior(**input_args) + else: + loss = trainer(**input_args) - total_loss += loss * batches - total_samples += batches + total_loss += loss - avg_loss = total_loss / total_samples + # compute the average loss across all processes - stats = {f"{tracker_context}-{loss_type}": avg_loss} - trainer.print(stats) + avg_loss = pad_gather_reduce(trainer, total_loss, method="mean") + stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss} - if exists(tracker): - tracker.log(stats, step=trainer.step.item() + 1) + # print and log results on main process + tracker.log(stats, step=trainer.step.item() + 1) + + return avg_loss def report_cosine_sims( trainer: DiffusionPriorTrainer, dataloader: DataLoader, text_conditioned: bool, - tracker: BaseTracker, - tracker_context: str = "validation", + tracker: Tracker, + split: str, + timesteps: int, + tracker_folder: str, ): trainer.eval() - if trainer.is_main_process(): - click.secho("Measuring Cosine-Similarity", fg="green", blink=True) + if trainer.accelerator.is_main_process: + click.secho( + f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps", + fg="green", + blink=True, + ) for test_image_embeddings, text_data in dataloader: test_image_embeddings = test_image_embeddings.to(trainer.device) @@ -127,9 +273,7 @@ def report_cosine_sims( # we are text conditioned, we produce an embedding from the tokenized text if text_conditioned: text_embedding, text_encodings = trainer.embed_text(text_data) - text_cond = dict( - text_embed=text_embedding, text_encodings=text_encodings - ) + text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings) else: text_embedding = text_data text_cond = dict(text_embed=text_embedding) @@ -150,8 +294,7 @@ def report_cosine_sims( text_encodings_shuffled = None text_cond_shuffled = dict( - text_embed=text_embed_shuffled, - text_encodings=text_encodings_shuffled + text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled ) # prepare the text embedding @@ -164,7 +307,9 @@ def report_cosine_sims( # predict on the unshuffled text embeddings predicted_image_embeddings = trainer.p_sample_loop( - test_image_embeddings.shape, text_cond + test_image_embeddings.shape, + text_cond, + timesteps=timesteps, ) predicted_image_embeddings = ( @@ -174,7 +319,9 @@ def report_cosine_sims( # predict on the shuffled embeddings predicted_unrelated_embeddings = trainer.p_sample_loop( - 
test_image_embeddings.shape, text_cond_shuffled + test_image_embeddings.shape, + text_cond_shuffled, + timesteps=timesteps, ) predicted_unrelated_embeddings = ( @@ -183,32 +330,97 @@ def report_cosine_sims( ) # calculate similarities - original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy() - predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy() - unrelated_similarity = ( - cos(text_embed, predicted_unrelated_embeddings).cpu().numpy() + orig_sim = pad_gather_reduce( + trainer, cos(text_embed, test_image_embeddings), method="mean" ) - predicted_img_similarity = ( - cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy() + pred_sim = pad_gather_reduce( + trainer, cos(text_embed, predicted_image_embeddings), method="mean" + ) + unrel_sim = pad_gather_reduce( + trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean" + ) + pred_img_sim = pad_gather_reduce( + trainer, + cos(test_image_embeddings, predicted_image_embeddings), + method="mean", ) stats = { - f"{tracker_context}/baseline similarity": np.mean(original_similarity), - f"{tracker_context}/similarity with text": np.mean(predicted_similarity), - f"{tracker_context}/similarity with original image": np.mean( - predicted_img_similarity - ), - f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity), - f"{tracker_context}/difference from baseline similarity": np.mean( - predicted_similarity - original_similarity - ), + f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim, + f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim, + f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim, + f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim, + f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim + - orig_sim, } - for k, v in stats.items(): - trainer.print(f"{tracker_context}/{k}: {v}") + tracker.log(stats, step=trainer.step.item() + 1) - if exists(tracker): - tracker.log(stats, step=trainer.step.item() + 1) + +def eval_model( + trainer: DiffusionPriorTrainer, + dataloader: DataLoader, + text_conditioned: bool, + split: str, + tracker: Tracker, + use_ema: bool, + report_cosine: bool, + report_loss: bool, + timesteps: List[int], + loss_type: str = None, +): + """ + Run evaluation on a model and track metrics + + returns: loss if requested + """ + trainer.eval() + + use_ema = "ema" if use_ema else "online" + tracker_folder = f"metrics/{use_ema}-{split}" + + # detemine if valid timesteps are passed + + min_timesteps = trainer.accelerator.unwrap_model( + trainer.diffusion_prior + ).sample_timesteps + max_timesteps = trainer.accelerator.unwrap_model( + trainer.diffusion_prior + ).noise_scheduler.num_timesteps + + assert all_between( + timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps + ), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}" + + # measure cosine metrics across various eta and timesteps + + if report_cosine: + for timestep in timesteps: + report_cosine_sims( + trainer, + dataloader=dataloader, + text_conditioned=text_conditioned, + tracker=tracker, + split=split, + timesteps=timestep, + tracker_folder=tracker_folder, + ) + + # measure loss on a seperate split of data + + if report_loss: + loss = report_validation_loss( + trainer=trainer, + dataloader=dataloader, + text_conditioned=text_conditioned, + use_ema=use_ema, + tracker=tracker, + split=split, + 
tracker_folder=tracker_folder, + loss_type=loss_type, + ) + + return loss # training script @@ -216,182 +428,327 @@ def report_cosine_sims( def train( trainer: DiffusionPriorTrainer, + tracker: Tracker, train_loader: DataLoader, eval_loader: DataLoader, test_loader: DataLoader, config: DiffusionPriorTrainConfig, ): - # distributed tracking with wandb - if trainer.accelerator.num_processes > 1: - os.environ["WANDB_START_METHOD"] = "thread" + # init timers + save_timer = Timer() # when to save + samples_timer = Timer() # samples/sec + validation_profiler = Timer() # how long is validation taking + validation_countdown = Timer() # when to perform evalutation - tracker = wandb.init( - name=f"RANK:{trainer.device}", - entity=config.tracker.wandb_entity, - project=config.tracker.wandb_project, - config=config.dict(), - group=GROUP, - ) + # keep track of best validation loss - # sync after tracker init - trainer.wait_for_everyone() - - # init a timer - timer = Timer() + best_validation_loss = config.train.best_validation_loss + samples_seen = config.train.num_samples_seen # do training - for img, txt in train_loader: - trainer.train() - current_step = trainer.step.item() + 1 - # place data on device - img = img.to(trainer.device) - txt = txt.to(trainer.device) + start_epoch = config.train.current_epoch - # pass to model - loss = trainer(text=txt, image_embed=img) + for epoch in range(start_epoch, config.train.epochs): + # if we finished out an old epoch, reset the distribution to be a full epoch + tracker.log({"tracking/epoch": epoch}, step=trainer.step.item()) - # display & log loss (will only print from main process) - trainer.print(f"Step {current_step}: Loss {loss}") + if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1: + if trainer.accelerator.is_main_process: + click.secho(f"Finished resumed epoch...resetting dataloader.") + train_loader.dataset.set_start(0) - # perform backprop & apply EMA updates - trainer.update() + for img, txt in train_loader: + # setup things every step - # track samples/sec/rank - samples_per_sec = img.shape[0] / timer.elapsed() + trainer.train() + current_step = trainer.step.item() + samples_timer.reset() - # samples seen - samples_seen = ( - config.data.batch_size * trainer.accelerator.num_processes * current_step - ) + # place data on device - # ema decay - ema_decay = trainer.ema_diffusion_prior.get_current_decay() + img = img.to(trainer.device) + txt = txt.to(trainer.device) - # Log on all processes for debugging - tracker.log( - { - "tracking/samples-sec": samples_per_sec, - "tracking/samples-seen": samples_seen, - "tracking/ema-decay": ema_decay, - "metrics/training-loss": loss, - }, - step=current_step, - ) + # pass to model - # Metric Tracking & Checkpointing (outside of timer's scope) - if current_step % EVAL_EVERY == 0: - eval_model( - trainer=trainer, - dataloader=eval_loader, - text_conditioned=config.prior.condition_on_text_encodings, - loss_type=config.prior.loss_type, - tracker_context="metrics/online-model-validation", - tracker=tracker, - use_ema=False, + loss = trainer(text=txt, image_embed=img) + + # perform backprop & apply EMA updates + + trainer.update() + + # gather info about training step + + all_loss = pad_gather_reduce(trainer, loss, method="mean") + num_samples = pad_gather_reduce(trainer, len(txt), method="sum") + samples_per_sec = num_samples / samples_timer.elapsed() + samples_seen += num_samples + ema_decay = trainer.ema_diffusion_prior.get_current_decay() + + # log + + tracker.log( + { + "tracking/samples-sec": 
samples_per_sec, + "tracking/samples-seen": samples_seen, + "tracking/ema-decay": ema_decay, + f"tracking/training-{config.prior.loss_type}": all_loss, + }, + step=current_step, ) - eval_model( - trainer=trainer, - dataloader=eval_loader, - text_conditioned=config.prior.condition_on_text_encodings, - loss_type=config.prior.loss_type, - tracker_context="metrics/ema-model-validation", - tracker=tracker, - use_ema=True, + # Metric Tracking @ Timed Intervals + + eval_delta = pad_gather_reduce( + trainer, validation_countdown.elapsed(), method="min" ) - report_cosine_sims( - trainer=trainer, - dataloader=eval_loader, - text_conditioned=config.prior.condition_on_text_encodings, - tracker=tracker, - tracker_context="metrics", - ) + if eval_delta != None and eval_delta > config.data.eval_every_seconds: + # begin timing how long this takes - if current_step % config.train.save_every == 0: - trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth") + validation_profiler.reset() - # reset timer for next round - timer.reset() + # package kwargs for evaluation + + eval_kwargs = { + "trainer": trainer, + "tracker": tracker, + "text_conditioned": config.prior.condition_on_text_encodings, + "timesteps": config.train.eval_timesteps, + } + + # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT + + eval_model( + dataloader=eval_loader, + loss_type=config.prior.loss_type, + split="validation", + use_ema=False, + report_cosine=False, + report_loss=True, + **eval_kwargs, + ) + + # EMA MODEL : COSINE : LOSS : VALIDATION DATA + + ema_val_loss = eval_model( + dataloader=eval_loader, + loss_type=config.prior.loss_type, + split="validation", + use_ema=True, + report_cosine=True, + report_loss=True, + **eval_kwargs, + ) + + tracker.log( + { + "tracking/validation length (minutes)": validation_profiler.elapsed() + / 60 + } + ) + + # check if the ema validation is the lowest seen yet + + if ema_val_loss < best_validation_loss: + best_validation_loss = ema_val_loss + + # go save the model as best + + save_trainer( + trainer=trainer, + tracker=tracker, + is_best=True, + is_latest=False, + samples_seen=samples_seen, + epoch=epoch, + best_validation_loss=best_validation_loss, + ) + + # reset timer for validaiton + + validation_countdown.reset() + + elif eval_delta is None: + click.secho( + f"Error occured reading the eval time on rank: {trainer.device}", + fg="yellow", + ) + + # save as latest model on schedule + + save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min") + + if save_delta != None and save_delta >= config.train.save_every_seconds: + save_trainer( + trainer=trainer, + tracker=tracker, + is_best=False, + is_latest=True, + samples_seen=samples_seen, + epoch=epoch, + best_validation_loss=best_validation_loss, + ) + + save_timer.reset() + + elif save_delta is None: + click.secho( + f"Error occured reading the save time on rank: {trainer.device}", + fg="yellow", + ) # evaluate on test data - eval_model( + if trainer.accelerator.is_main_process: + click.secho(f"Starting Test", fg="red") + + # save one last time as latest before beginning validation + + save_trainer( + tracker=tracker, + trainer=trainer, + is_best=False, + is_latest=True, + samples_seen=samples_seen, + epoch=epoch, + best_validation_loss=best_validation_loss, + ) + + test_loss = eval_model( trainer=trainer, dataloader=test_loader, text_conditioned=config.prior.condition_on_text_encodings, - loss_type=config.prior.loss_type, - tracker_context="test", + split="test", tracker=tracker, + use_ema=True, + 
report_cosine=False, + report_loss=True, + timesteps=config.train.eval_timesteps, + loss_type=config.prior.loss_type, ) - report_cosine_sims( - trainer, - test_loader, - config.prior.condition_on_text_encodings, - tracker, - tracker_context="test", - ) + if test_loss < best_validation_loss: + best_validation_loss = test_loss + + # go save the model as best + + save_trainer( + trainer=trainer, + tracker=tracker, + is_best=True, + is_latest=False, + samples_seen=samples_seen, + epoch=epoch, + best_validation_loss=test_loss, + ) -def initialize_training(config, accelerator=None): +def initialize_training(config_file, accelerator): """ Parse the configuration file, and prepare everything necessary for training """ + # load the configuration file + if accelerator.is_main_process: + click.secho(f"Loading configuration from {config_file}", fg="green") + + config = TrainDiffusionPriorConfig.from_json_path(config_file) + + # seed + + set_seed(config.train.random_seed) # get a device - if accelerator: - device = accelerator.device - click.secho(f"Accelerating on: {device}", fg="yellow") - else: - if torch.cuda.is_available(): - click.secho("GPU detected, defaulting to cuda:0", fg="yellow") - device = "cuda:0" - else: - click.secho("No GPU detected...using cpu", fg="yellow") - device = "cpu" + device = accelerator.device # make the trainer (will automatically distribute if possible & configured) - trainer = make_model(config.prior, config.train, device, accelerator).to(device) + trainer: DiffusionPriorTrainer = make_model( + config.prior, config.train, device, accelerator + ).to(device) + + # create a tracker + + tracker = create_tracker( + accelerator, config, config_file, dummy=accelerator.process_index != 0 + ) # reload from chcekpoint - if config.load.resume == True: - click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan") - trainer.load(config.load.source) + if tracker.can_recall: + current_epoch, best_validation_loss, samples_seen = recall_trainer( + tracker=tracker, trainer=trainer + ) + + # display best values + if trainer.accelerator.is_main_process: + click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow") + + # update config to reflect recalled values + config.train.num_samples_seen = samples_seen + config.train.current_epoch = current_epoch + config.train.best_validation_loss = best_validation_loss # fetch and prepare data - if trainer.is_main_process(): - click.secho("Grabbing data from source", fg="blue", blink=True) + if trainer.accelerator.is_main_process: + click.secho("Grabbing data...", fg="blue", blink=True) + trainer.accelerator.wait_for_everyone() img_reader = get_reader( text_conditioned=trainer.text_conditioned, img_url=config.data.image_url, meta_url=config.data.meta_url, ) + # calculate start point within epoch + + trainer.accelerator.wait_for_everyone() + train_loader, eval_loader, test_loader = make_splits( text_conditioned=trainer.text_conditioned, batch_size=config.data.batch_size, - num_data_points=NUM_DATA_POINTS, + num_data_points=config.data.num_data_points, train_split=config.data.splits.train, eval_split=config.data.splits.val, image_reader=img_reader, - rank=accelerator.state.process_index if exists(accelerator) else 0, - world_size=accelerator.state.num_processes if exists(accelerator) else 1, - start=START, + rank=accelerator.state.process_index, + world_size=accelerator.state.num_processes, + start=0, ) - # wait for everyone to load data before continuing - 
trainer.wait_for_everyone() + # update the start point to finish out the epoch on a resumed run + + if tracker.can_recall: + samples_seen = config.train.num_samples_seen + length = ( + config.data.num_data_points + if samples_seen <= img_reader.count + else img_reader.count + ) + scaled_samples = length * config.train.current_epoch + start_point = ( + scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen + ) + + if trainer.accelerator.is_main_process: + click.secho(f"Resuming at sample: {start_point}", fg="yellow") + + train_loader.dataset.set_start(start_point) # start training + + if trainer.accelerator.is_main_process: + click.secho( + f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}", + fg="yellow", + ) + train( trainer=trainer, + tracker=tracker, train_loader=train_loader, eval_loader=eval_loader, test_loader=test_loader, @@ -400,23 +757,13 @@ def initialize_training(config, accelerator=None): @click.command() -@click.option("--hfa", default=True) -@click.option("--config_path", default="configs/prior.json") -def main(hfa, config_path): - # start HFA if requested - if hfa: - accelerator = Accelerator() - else: - accelerator = None +@click.option("--config_file", default="configs/train_prior_config.example.json") +def main(config_file): + # start HFA + accelerator = Accelerator() - # load the configuration file on main process - if not exists(accelerator) or accelerator.is_main_process: - click.secho(f"Loading configuration from {config_path}", fg="green") - - config = TrainDiffusionPriorConfig.from_json_path(config_path) - - # send config to get processed - initialize_training(config, accelerator) + # setup training + initialize_training(config_file, accelerator) if __name__ == "__main__":
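For reference, the start-point arithmetic used above when a checkpoint is recalled can be restated as a small helper (the epoch_start_point name and the example numbers are illustrative; the script inlines this logic in initialize_training and then passes the result to train_loader.dataset.set_start):

def epoch_start_point(samples_seen: int, current_epoch: int,
                      num_data_points: int, reader_count: int) -> int:
    # the per-epoch length is num_data_points unless the run has already consumed
    # more samples than the reader actually holds
    length = num_data_points if samples_seen <= reader_count else reader_count
    scaled_samples = length * current_epoch
    # offset into the interrupted epoch; fall back to samples_seen if the epoch
    # boundary has already been passed according to the bookkeeping
    return scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen

# e.g. 100k samples per epoch, resumed in epoch 2 after 150k samples seen in total
print(epoch_start_point(samples_seen=150_000, current_epoch=2,
                        num_data_points=100_000, reader_count=1_000_000))  # -> 50000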