diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index f9fc1d1..6cc609b 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version
 
+from accelerate import Accelerator
+
 import numpy as np
 
 # helper functions
@@ -189,13 +191,13 @@ class EMA(nn.Module):
         By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
 
         @crowsonkb's notes on EMA Warmup:
-
+
         If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
         good values for models you plan to train for a million or more steps (reaches decay
         factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
         you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps).
-
+
         Args:
             inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
             power (float): Exponential factor of EMA warmup. Default: 1.
@@ -205,7 +207,7 @@ class EMA(nn.Module):
         self,
         model,
         beta = 0.9999,
-        update_after_step = 10000,
+        update_after_step = 100,
         update_every = 10,
         inv_gamma = 1.0,
         power = 2/3,
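The decay figures quoted in the docstring above can be checked directly. Assuming crowsonkb's warmup rule that the docstring references, the decay at a given step is 1 - (1 + step / inv_gamma) ** -power, clamped at beta; the helper below is an illustrative sketch, not part of this diff:

def warmup_decay(step, inv_gamma = 1.0, power = 2 / 3, beta = 0.9999):
    # decay ramps from 0 toward 1 as step grows, and is capped at the target beta
    return min(beta, 1 - (1 + step / inv_gamma) ** -power)

print(warmup_decay(31_600))     # ~0.999, matching "decay factor 0.999 at 31.6K steps"
print(warmup_decay(1_000_000))  # reaches the beta cap of 0.9999, matching "0.9999 at 1M steps"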
@@ -280,6 +282,7 @@ class EMA(nn.Module):
     def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
 
+
 # diffusion prior trainer
 
 def prior_sample_in_chunks(fn):
@@ -303,88 +306,189 @@ class DiffusionPriorTrainer(nn.Module):
         max_grad_norm = None,
         amp = False,
         group_wd_params = True,
+        device = None,
+        accelerator = None,
         **kwargs
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
+        assert not exists(accelerator) or isinstance(accelerator, Accelerator)
+        assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
 
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
+        # assign some helpful member vars
+
+        self.accelerator = accelerator
+        self.device = accelerator.device if exists(accelerator) else device
+        self.text_conditioned = diffusion_prior.condition_on_text_encodings
+
+        # save model
+
+        self.diffusion_prior = diffusion_prior
+
-        # exponential moving average
-
-        self.use_ema = use_ema
-        if self.use_ema:
-            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
-
         # optimizer and mixed precision stuff
 
         self.amp = amp
 
         self.scaler = GradScaler(enabled = amp)
 
+        self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
+
         self.optimizer = get_optimizer(
-            diffusion_prior.parameters(),
-            lr = lr,
-            wd = wd,
-            eps = eps,
-            group_wd_params = group_wd_params,
+            self.diffusion_prior.parameters(),
+            **self.optim_kwargs,
             **kwargs
         )
 
+        # distribute the model if using HFA
+
+        if exists(self.accelerator):
+            self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
+
+        # exponential moving average stuff
+
+        self.use_ema = use_ema
+
+        if self.use_ema:
+            self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
+
         # gradient clipping if needed
 
         self.max_grad_norm = max_grad_norm
 
+        # track steps internally
+        self.register_buffer('step', torch.tensor([0]))
+
+    # accelerator wrappers
+
+    def print(self, msg):
+        if exists(self.accelerator):
+            self.accelerator.print(msg)
+        else:
+            print(msg)
+
+    def unwrap_model(self, model):
+        if exists(self.accelerator):
+            return self.accelerator.unwrap_model(model)
+        else:
+            return model
+
+    def wait_for_everyone(self):
+        if exists(self.accelerator):
+            self.accelerator.wait_for_everyone()
+
+    def is_main_process(self):
+        if exists(self.accelerator):
+            return self.accelerator.is_main_process
+        else:
+            return True
+
+    def clip_grad_norm_(self, *args):
+        if exists(self.accelerator):
+            return self.accelerator.clip_grad_norm_(*args)
+        else:
+            return torch.nn.utils.clip_grad_norm_(*args)
+
+    def backprop(self, x):
+        if exists(self.accelerator):
+            self.accelerator.backward(x)
+        else:
+            try:
+                x.backward()
+            except Exception as e:
+                self.print(f"Caught error in backprop call: {e}")
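Every wrapper above follows the same pattern: delegate to the Accelerator when one was supplied, otherwise fall back to the plain single-process equivalent. The trainer can therefore be built either way; a minimal sketch, where the prior construction and learning rate are placeholders:

# distributed / mixed-precision plumbing handled by HF Accelerate
trainer = DiffusionPriorTrainer(diffusion_prior = prior, lr = 1.1e-4, accelerator = Accelerator())

# or plain single-device training, in which case a device must be given
trainer = DiffusionPriorTrainer(diffusion_prior = prior, lr = 1.1e-4, device = 'cuda:0')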
+    # utility
+
     def save(self, path, overwrite = True, **kwargs):
-        path = Path(path)
-        assert not (path.exists() and not overwrite)
-        path.parent.mkdir(parents = True, exist_ok = True)
+        # ensure all processes have synced before continuing
+        self.wait_for_everyone()
 
-        save_obj = dict(
-            scaler = self.scaler.state_dict(),
-            optimizer = self.optimizer.state_dict(),
-            model = self.diffusion_prior.state_dict(),
-            version = __version__,
-            step = self.step.item(),
-            **kwargs
-        )
+        # only save on the main process
+        if self.is_main_process():
+            self.print(f"Saving checkpoint at step: {self.step.item()}")
+
+            path = Path(path)
+            assert not (path.exists() and not overwrite)
+            path.parent.mkdir(parents = True, exist_ok = True)
 
-        if self.use_ema:
-            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+            save_obj = dict(
+                scaler = self.scaler.state_dict(),
+                optimizer = self.optimizer.state_dict(),
+                model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
+                version = version.parse(__version__),
+                step = self.step.item(),
+                **kwargs
+            )
 
-        torch.save(save_obj, str(path))
+            if self.use_ema:
+                save_obj = {
+                    **save_obj,
+                    'ema': self.ema_diffusion_prior.state_dict(),
+                    'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
+                }
 
-    def load(self, path, only_model = False, strict = True):
+            torch.save(save_obj, str(path))
+
+    def load(self, path, overwrite_lr = True, strict = True):
+        """
+        Load a checkpoint of a diffusion prior trainer.
+
+        Will load the entire trainer, including the optimizer and EMA.
+
+        Params:
+            - path (str): a path to the DiffusionPriorTrainer checkpoint file
+            - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
+            - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
+
+        Returns:
+            loaded_obj (dict): The loaded checkpoint dictionary
+        """
+
+        # all processes need to load the checkpoint. no restriction here
        path = Path(path)
         assert path.exists()
 
-        loaded_obj = torch.load(str(path))
+        loaded_obj = torch.load(str(path), map_location=self.device)
 
         if version.parse(__version__) != loaded_obj['version']:
             print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
 
-        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+        # unwrap the model when loading from checkpoint
+        self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
+
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
 
-        if only_model:
-            return loaded_obj
-
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
 
+        if overwrite_lr:
+            new_lr = self.optim_kwargs["lr"]
+
+            self.print(f"Overriding LR to be {new_lr}")
+
+            for group in self.optimizer.param_groups:
+                group["lr"] = new_lr
+
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+            # below may not be necessary, but I had a suspicion that this wasn't being loaded correctly
+            self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
+
+        # sync and inform
+        self.wait_for_everyone()
+        self.print("Loaded model")
 
         return loaded_obj
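Taken together, save and load now form a rank-safe round trip: every rank reaches both calls, only the main process actually writes, and wait_for_everyone keeps any rank from racing ahead. A sketch of resuming, with a hypothetical checkpoint path:

trainer.save('./checkpoints/prior.pth')           # no-op on non-main ranks
loaded = trainer.load('./checkpoints/prior.pth')  # every rank loads; the stored LR is
                                                  # replaced by the trainer's own LR
                                                  # unless overwrite_lr = False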
+    # model functionality
+
     def update(self):
+        # wait until all ranks are ready before updating
+        self.wait_for_everyone()
+
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
-            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+            # utilize HFA clipping where applicable
+            self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
 
         self.scaler.step(self.optimizer)
         self.scaler.update()
@@ -399,17 +503,32 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+        else:
+            return self.diffusion_prior.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+        else:
+            return self.diffusion_prior.sample(*args, **kwargs)
 
     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+        else:
+            return self.diffusion_prior.sample_batch_size(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
 
     @cast_torch_tensor
     def forward(
@@ -427,8 +546,10 @@ class DiffusionPriorTrainer(nn.Module):
 
             total_loss += loss.item()
 
+        # backprop with accelerate if applicable
+
         if self.training:
-            self.scaler.scale(loss).backward()
+            self.backprop(self.scaler.scale(loss))
 
         return total_loss
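That closes out the trainer changes: forward now only accumulates (scaled) gradients through backprop, and update applies them. Assuming the unchanged remainder of update() still advances the EMA and step buffer as in the pre-diff trainer, one optimization step is the following sketch (tensor names are placeholders):

loss = trainer(text = text_tokens, image_embed = image_embeds)  # forward + backprop while training
trainer.update()  # sync ranks, clip, step the optimizer and scaler (plus EMA / step counter)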
diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 76df513..f5e0a15 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -1,77 +1,136 @@
-from pathlib import Path
+# TODO: add start, num_data_points, eval_every and group to config
+# TODO: switch back to repo's wandb
+
+START = 0
+NUM_DATA_POINTS = 250e6
+EVAL_EVERY = 1000
+GROUP = "distributed"
+
+import os
 
 import click
-import math
-import numpy as np
+import wandb
 import torch
-import clip
 
 from torch import nn
+from torch.nn.functional import normalize
+from torch.utils.data import DataLoader
 
-from dalle2_pytorch.dataloaders import make_splits, get_reader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
-from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
+import numpy as np
 
-from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
-from dalle2_pytorch.utils import Timer, print_ribbon
+from accelerate import Accelerator
 
-from tqdm import tqdm
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.train_configs import (
+    DiffusionPriorTrainConfig,
+    TrainDiffusionPriorConfig,
+)
+from dalle2_pytorch.trackers import BaseTracker, WandbTracker
+from dalle2_pytorch import DiffusionPriorTrainer
 
-# constants
-REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
 
-tracker = WandbTracker()
+# helpers
 
-# helpers functions
+cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+
 
 def exists(val):
-    val is not None
+    return val is not None
 
-# functions
 
-def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
-    model.eval()
+def make_model(
+    prior_config, train_config, device: str = None, accelerator: Accelerator = None
+):
+    # create model from config
+    diffusion_prior = prior_config.create()
+
+    # instantiate the trainer
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior=diffusion_prior,
+        lr=train_config.lr,
+        wd=train_config.wd,
+        max_grad_norm=train_config.max_grad_norm,
+        amp=train_config.amp,
+        use_ema=train_config.use_ema,
+        device=device,
+        accelerator=accelerator,
+    )
+
+    return trainer
+
+
+# eval functions
+
+
+def eval_model(
+    trainer: DiffusionPriorTrainer,
+    dataloader: DataLoader,
+    text_conditioned: bool,
+    loss_type: str,
+    phase: str,
+    tracker: BaseTracker = None,
+    use_ema: bool = True,
+):
+    trainer.eval()
+
+    if trainer.is_main_process():
+        click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
 
     with torch.no_grad():
-        total_loss = 0.
-        total_samples = 0.
+        total_loss = 0.0
+        total_samples = 0.0
 
-        for image_embeddings, text_data in tqdm(dataloader):
-            image_embeddings = image_embeddings.to(device)
-            text_data = text_data.to(device)
+        for image_embeddings, text_data in dataloader:
+            image_embeddings = image_embeddings.to(trainer.device)
+            text_data = text_data.to(trainer.device)
 
             batches = image_embeddings.shape[0]
 
             input_args = dict(image_embed=image_embeddings)
+
             if text_conditioned:
-                input_args = dict(**input_args, text = text_data)
+                input_args = dict(**input_args, text=text_data)
             else:
                 input_args = dict(**input_args, text_embed=text_data)
 
-            loss = model(**input_args)
+            if use_ema:
+                loss = trainer.ema_diffusion_prior(**input_args)
+            else:
+                loss = trainer(**input_args)
 
             total_loss += loss * batches
             total_samples += batches
 
-        avg_loss = (total_loss / total_samples)
+        avg_loss = total_loss / total_samples
 
-        tracker.log({f'{phase} {loss_type}': avg_loss})
+        stats = {f"{phase}/{loss_type}": avg_loss}
+        trainer.print(stats)
 
+        if exists(tracker):
+            tracker.log(stats, step=trainer.step.item() + 1)
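Note that eval_model reuses the trainer's conditioning contract: tokenized text goes in under text for a text-conditioned prior, precomputed embeddings under text_embed otherwise, and use_ema picks between the EMA copy and the live model. A hypothetical standalone call, outside the training loop:

eval_model(trainer, eval_loader, text_conditioned = True, loss_type = 'l2',
           phase = 'validation', tracker = None, use_ema = True)  # with no tracker, stats are only printed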
-def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
-    diffusion_prior.eval()
 
-    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+def report_cosine_sims(
+    trainer: DiffusionPriorTrainer,
+    dataloader: DataLoader,
+    text_conditioned: bool,
+    tracker: BaseTracker,
+    phase: str = "validation",
+):
+    trainer.eval()
 
-    for test_image_embeddings, text_data in tqdm(dataloader):
-        test_image_embeddings = test_image_embeddings.to(device)
-        text_data = text_data.to(device)
+    if trainer.is_main_process():
+        click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
+
+    for test_image_embeddings, text_data in dataloader:
+        test_image_embeddings = test_image_embeddings.to(trainer.device)
+        text_data = text_data.to(trainer.device)
 
         # we are text conditioned, we produce an embedding from the tokenized text
         if text_conditioned:
-            text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
-                text_data)
-            text_cond = dict(text_embed=text_embedding,
-                             text_encodings=text_encodings, mask=text_mask)
+            text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
+            text_cond = dict(
+                text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
+            )
         else:
             text_embedding = text_data
             text_cond = dict(text_embed=text_embedding)
@@ -82,8 +141,7 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
         # roll the text to simulate "unrelated" captions
         rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
         text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / \
-            text_embed_shuffled.norm(dim=1, keepdim=True)
+        text_embed_shuffled = normalize(text_embed_shuffled)
 
         if text_conditioned:
             text_encodings_shuffled = text_encodings[rolled_idx]
@@ -92,294 +150,272 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
             text_encodings_shuffled = None
             text_mask_shuffled = None
 
-        text_cond_shuffled = dict(text_embed=text_embed_shuffled,
-                                  text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
+        text_cond_shuffled = dict(
+            text_embed=text_embed_shuffled,
+            text_encodings=text_encodings_shuffled,
+            mask=text_mask_shuffled,
+        )
 
         # prepare the text embedding
-        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
+        text_embed = normalize(text_embedding)
 
         # prepare image embeddings
-        test_image_embeddings = test_image_embeddings / \
-            test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = normalize(test_image_embeddings)
 
         # predict on the unshuffled text embeddings
-        predicted_image_embeddings = diffusion_prior.p_sample_loop(
-            test_image_embeddings.shape, text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / \
-            predicted_image_embeddings.norm(dim=1, keepdim=True)
+        predicted_image_embeddings = trainer.p_sample_loop(
+            test_image_embeddings.shape, text_cond
+        )
+
+        predicted_image_embeddings = normalize(predicted_image_embeddings)
 
         # predict on the shuffled embeddings
-        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
-            test_image_embeddings.shape, text_cond_shuffled)
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
-            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+        predicted_unrelated_embeddings = trainer.p_sample_loop(
+            test_image_embeddings.shape, text_cond_shuffled
+        )
+
+        predicted_unrelated_embeddings = normalize(predicted_unrelated_embeddings)
 
         # calculate similarities
-        original_similarity = cos(
-            text_embed, test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(
-            text_embed, predicted_image_embeddings).cpu().numpy()
-        unrelated_similarity = cos(
-            text_embed, predicted_unrelated_embeddings).cpu().numpy()
-        predicted_img_similarity = cos(
-            test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-        tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
-                     "CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
-                     "CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
-                     "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
-                     "Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
+        original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
+        unrelated_similarity = (
+            cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
+        )
+        predicted_img_similarity = (
+            cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
+        )
+
+        stats = {
+            f"{phase}/baseline similarity": np.mean(original_similarity),
+            f"{phase}/similarity with text": np.mean(predicted_similarity),
+            f"{phase}/similarity with original image": np.mean(
+                predicted_img_similarity
+            ),
+            f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
+            f"{phase}/difference from baseline similarity": np.mean(
+                predicted_similarity - original_similarity
+            ),
+        }
+
+        for k, v in stats.items():
+            trainer.print(f"{k}: {v}")
+
+        if exists(tracker):
+            tracker.log(stats, step=trainer.step.item() + 1)
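The "unrelated caption" baseline above comes from rolling the batch by one position, so each image is scored against a different sample's caption. In miniature, with toy tensors (illustrative only):

import torch
from torch.nn.functional import normalize

text_embed  = normalize(torch.randn(4, 768))  # 4 caption embeddings
image_embed = normalize(torch.randn(4, 768))  # their paired image embeddings

rolled_idx = torch.roll(torch.arange(4), 1)   # tensor([3, 0, 1, 2])
unrelated  = text_embed[rolled_idx]           # image i is now paired with caption i - 1

cos = torch.nn.CosineSimilarity(dim = 1, eps = 1e-6)
print(cos(image_embed, text_embed).mean(), cos(image_embed, unrelated).mean())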
 
 
+# training script
+
+
+def train(
+    trainer: DiffusionPriorTrainer,
+    train_loader: DataLoader,
+    eval_loader: DataLoader,
+    test_loader: DataLoader,
+    config: DiffusionPriorTrainConfig,
+):
+    # number of ranks in the world (1 when not running under an accelerator)
+    world_size = trainer.accelerator.num_processes if exists(trainer.accelerator) else 1
+
+    # distributed tracking with wandb
+    if world_size > 1:
+        os.environ["WANDB_START_METHOD"] = "thread"
+
+    tracker = wandb.init(
+        name=f"RANK:{trainer.device}",
+        entity=config.tracker.wandb_entity,
+        project=config.tracker.wandb_project,
+        config=config.dict(),
+        group=GROUP,
+    )
+
+    # sync after tracker init
+    trainer.wait_for_everyone()
+
+    # init a timer
+    timer = Timer()
+
+    # do training
+    for img, txt in train_loader:
+        trainer.train()
+        current_step = trainer.step.item() + 1
+
+        # place data on device
+        img = img.to(trainer.device)
+        txt = txt.to(trainer.device)
+
+        # pass to model
+        loss = trainer(text=txt, image_embed=img)
+
+        # display & log loss (will only print from main process)
+        trainer.print(f"Step {current_step}: Loss {loss}")
+
+        # perform backprop & apply EMA updates
+        trainer.update()
+
+        # track samples/sec/rank
+        samples_per_sec = img.shape[0] / timer.elapsed()
+
+        # samples seen across all ranks
+        samples_seen = config.data.batch_size * world_size * current_step
+
+        # ema decay
+        ema_decay = trainer.ema_diffusion_prior.get_current_decay()
+
+        # Log on all processes for debugging
+        tracker.log(
+            {
+                "training/loss": loss,
+                "samples/sec/rank": samples_per_sec,
+                "samples/seen": samples_seen,
+                "ema/decay": ema_decay,
+            },
+            step=current_step,
+        )
+
+        # Metric Tracking & Checkpointing (outside of timer's scope)
+        if current_step % EVAL_EVERY == 0:
+            eval_model(
+                trainer,
+                eval_loader,
+                config.prior.condition_on_text_encodings,
+                config.prior.loss_type,
+                "training/validation",
+                tracker,
+                use_ema=False,
+            )
+
+            eval_model(
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                loss_type=config.prior.loss_type,
+                phase="ema/validation",
+                tracker=tracker,
+                use_ema=True,
+            )
+
+            report_cosine_sims(
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                tracker=tracker,
+                phase="ema/validation",
+            )
+
+        if current_step % config.train.save_every == 0:
+            trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
+
+        # reset timer for next round
+        timer.reset()
+
+    # evaluate on test data
+
+    eval_model(
+        trainer=trainer,
+        dataloader=test_loader,
+        text_conditioned=config.prior.condition_on_text_encodings,
+        loss_type=config.prior.loss_type,
+        phase="test",
+        tracker=tracker,
+    )
+
+    report_cosine_sims(
+        trainer,
+        test_loader,
+        config.prior.condition_on_text_encodings,
+        tracker,
+        phase="test",
+    )
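The per-rank wandb runs above (grouped under GROUP) assume the script is launched through HF Accelerate, which is also what populates the rank and world size used by make_splits below. Under the usual Accelerate workflow that would look something like the following, though the exact flags depend on your machine configuration:

# accelerate config      (one-time, interactive machine setup)
# accelerate launch train_diffusion_prior.py --config_path configs/prior.json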
+
+
+def initialize_training(config, accelerator=None):
+    """
+    Parse the configuration file, and prepare everything necessary for training
+    """
+
+    # get a device
+
+    if accelerator:
+        device = accelerator.device
+        click.secho(f"Accelerating on: {device}", fg="yellow")
+    else:
+        if torch.cuda.is_available():
+            click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
+            device = "cuda:0"
+        else:
+            click.secho("No GPU detected...using cpu", fg="yellow")
+            device = "cpu"
+
+    # make the trainer (will automatically distribute if possible & configured)
+
+    trainer = make_model(config.prior, config.train, device, accelerator).to(device)
+
+    # reload from checkpoint
+
+    if config.load.resume:
+        click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
+        trainer.load(config.load.source)
+
+    # fetch and prepare data
+
+    if trainer.is_main_process():
+        click.secho("Grabbing data from source", fg="blue", blink=True)
+
+    img_reader = get_reader(
+        text_conditioned=trainer.text_conditioned,
+        img_url=config.data.image_url,
+        meta_url=config.data.meta_url,
+    )
+
+    train_loader, eval_loader, test_loader = make_splits(
+        text_conditioned=trainer.text_conditioned,
+        batch_size=config.data.batch_size,
+        num_data_points=NUM_DATA_POINTS,
+        train_split=config.data.splits.train,
+        eval_split=config.data.splits.val,
+        image_reader=img_reader,
+        rank=accelerator.state.process_index if exists(accelerator) else 0,
+        world_size=accelerator.state.num_processes if exists(accelerator) else 1,
+        start=START,
+    )
+
+    # wait for everyone to load data before continuing
+    trainer.wait_for_everyone()
+
+    # start training
+    train(
+        trainer=trainer,
+        train_loader=train_loader,
+        eval_loader=eval_loader,
+        test_loader=test_loader,
+        config=config,
+    )
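initialize_training expects a TrainDiffusionPriorConfig loaded from JSON. Judging only from the fields this script reads, a config file would contain at least something shaped like the sketch below; treat it as a hedged illustration (values echo the old CLI defaults where those exist, and are otherwise placeholders) and consult dalle2_pytorch/train_configs.py for the authoritative schema:

{
    "prior": { "condition_on_text_encodings": true, "loss_type": "l2", ... },
    "train": { "lr": 1.1e-4, "wd": 6.02e-2, "max_grad_norm": 0.5, "amp": false, "use_ema": true, "save_every": 1000 },
    "data": { "batch_size": 320, "image_url": "...", "meta_url": "...", "splits": { "train": 0.9, "val": 1e-7, ... } },
    "tracker": { "wandb_entity": "laion", "wandb_project": "diffusion-prior", "data_path": "./checkpoints" },
    "load": { "resume": false, "source": null }
}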
 
 
 @click.command()
-@click.option("--wandb-entity", default="laion")
-@click.option("--wandb-project", default="diffusion-prior")
-@click.option("--wandb-dataset", default="LAION-5B")
-@click.option("--wandb-arch", default="DiffusionPrior")
-@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
-@click.option("--learning-rate", default=1.1e-4)
-@click.option("--weight-decay", default=6.02e-2)
-@click.option("--dropout", default=5e-2)
-@click.option("--max-grad-norm", default=0.5)
-@click.option("--num-data-points", default=250e6)
-@click.option("--batch-size", default=320)
-@click.option("--num-epochs", default=5)
-@click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.9)
-@click.option("--val-percent", default=1e-7)
-@click.option("--test-percent", default=0.0999999)
-@click.option("--dpn-depth", default=12)
-@click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=12)
-@click.option("--dp-condition-on-text-encodings", default=True)
-@click.option("--dp-timesteps", default=1000)
-@click.option("--dp-normformer", default=True)
-@click.option("--dp-cond-drop-prob", default=0.1)
-@click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default="ViT-L/14")
-@click.option("--amp", default=False)
-@click.option("--save-interval", default=120)
-@click.option("--save-path", default="./diffusion_prior_checkpoints")
-@click.option("--pretrained-model-path", default=None)
-@click.option("--gpu-device", default=0)
-def train(
-    wandb_entity,
-    wandb_project,
-    wandb_dataset,
-    wandb_arch,
-    image_embed_url,
-    text_embed_url,
-    meta_url,
-    learning_rate,
-    weight_decay,
-    dropout,
-    max_grad_norm,
-    num_data_points,
-    batch_size,
-    num_epochs,
-    image_embed_dim,
-    train_percent,
-    val_percent,
-    test_percent,
-    dpn_depth,
-    dpn_dim_head,
-    dpn_heads,
-    dp_condition_on_text_encodings,
-    dp_timesteps,
-    dp_normformer,
-    dp_cond_drop_prob,
-    dp_loss_type,
-    clip,
-    amp,
-    save_interval,
-    save_path,
-    pretrained_model_path,
-    gpu_device
-):
-    config = {
-        "learning_rate": learning_rate,
-        "architecture": wandb_arch,
-        "dataset": wandb_dataset,
-        "weight_decay": weight_decay,
-        "max_gradient_clipping_norm": max_grad_norm,
-        "batch_size": batch_size,
-        "epochs": num_epochs,
-        "diffusion_prior_network": {
-            "depth": dpn_depth,
-            "dim_head": dpn_dim_head,
-            "heads": dpn_heads,
-            "normformer": dp_normformer
-        },
-        "diffusion_prior": {
-            "condition_on_text_encodings": dp_condition_on_text_encodings,
-            "timesteps": dp_timesteps,
-            "cond_drop_prob": dp_cond_drop_prob,
-            "loss_type": dp_loss_type,
-            "clip": clip
-        }
-    }
-
-    # Check if DPRIOR_PATH exists(saved model path)
-
-    DPRIOR_PATH = pretrained_model_path
-    RESUME = exists(DPRIOR_PATH)
-
-    if not RESUME:
-        tracker.init(
-            entity = wandb_entity,
-            project = wandb_project,
-            config = config
-        )
-
-    # Obtain the utilized device.
-
-    has_cuda = torch.cuda.is_available()
-    if has_cuda:
-        device = torch.device(f"cuda:{gpu_device}")
-        torch.cuda.set_device(device)
-
-    # Training loop
-    # diffusion prior network
-
-    prior_network = DiffusionPriorNetwork(
-        dim = image_embed_dim,
-        depth = dpn_depth,
-        dim_head = dpn_dim_head,
-        heads = dpn_heads,
-        attn_dropout = dropout,
-        ff_dropout = dropout,
-        normformer = dp_normformer
-    )
-
-    # Load clip model if text-conditioning
-    if dp_condition_on_text_encodings:
-        clip_adapter = OpenAIClipAdapter(clip)
+@click.option("--hfa", default=True)
+@click.option("--config_path", default="configs/prior.json")
+def main(hfa, config_path):
+    # start HFA if requested
+    if hfa:
+        accelerator = Accelerator()
     else:
-        clip_adapter = None
+        accelerator = None
 
-    # diffusion prior with text embeddings and image embeddings pre-computed
+    # load the configuration file (announce it from the main process only)
+    if not exists(accelerator) or accelerator.is_main_process:
+        click.secho(f"Loading configuration from {config_path}", fg="green")
 
-    diffusion_prior = DiffusionPrior(
-        net = prior_network,
-        clip = clip_adapter,
-        image_embed_dim = image_embed_dim,
-        timesteps = dp_timesteps,
-        cond_drop_prob = dp_cond_drop_prob,
-        loss_type = dp_loss_type,
-        condition_on_text_encodings = dp_condition_on_text_encodings
-    )
+    config = TrainDiffusionPriorConfig.from_json_path(config_path)
 
-    # Load pre-trained model from DPRIOR_PATH
-
-    if RESUME:
-        diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
-        tracker.init(entity = wandb_entity, project = wandb_project, config = config)
-
-    # diffusion prior trainer
-
-    trainer = DiffusionPriorTrainer(
-        diffusion_prior = diffusion_prior,
-        lr = learning_rate,
-        wd = weight_decay,
-        max_grad_norm = max_grad_norm,
-        amp = amp,
-    ).to(device)
-
-    # load optimizer and scaler
-
-    if RESUME:
-        trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
-        trainer.scaler.load_state_dict(loaded_obj['scaler'])
-
-    # Create save_path if it doesn't exist
-
-    Path(save_path).mkdir(exist_ok = True, parents = True)
-
-    # Utilize wrapper to abstract away loader logic
-    print_ribbon("Downloading Embeddings")
-    reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
-
-    if dp_condition_on_text_encodings:
-        reader_args = dict(**reader_args, meta_url=meta_url)
-        img_reader = get_reader(**reader_args)
-        train_loader, eval_loader, test_loader = make_splits(
-            text_conditioned=dp_condition_on_text_encodings,
-            batch_size=batch_size,
-            num_data_points=num_data_points,
-            train_split=train_percent,
-            eval_split=val_percent,
-            image_reader=img_reader
-        )
-    else:
-        reader_args = dict(**reader_args, txt_url=text_embed_url)
-        img_reader, txt_reader = get_reader(**reader_args)
-        train_loader, eval_loader, test_loader = make_splits(
-            text_conditioned=dp_condition_on_text_encodings,
-            batch_size=batch_size,
-            num_data_points=num_data_points,
-            train_split=train_percent,
-            eval_split=val_percent,
-            image_reader=img_reader,
-            text_reader=txt_reader
-        )
-
-    ### Training code ###
-
-    step = 1
-    timer = Timer()
-    epochs = num_epochs
-
-    for _ in range(epochs):
-
-        for image, text in tqdm(train_loader):
-            diffusion_prior.train()
-
-            image = image.to(device)
-            text = text.to(device)
-
-            input_args = dict(image_embed=image)
-            if dp_condition_on_text_encodings:
-                input_args = dict(**input_args, text = text)
-            else:
-                input_args = dict(**input_args, text_embed=text)
-
-            loss = trainer(**input_args)
-
-            # Samples per second
-
-            samples_per_sec = batch_size * step / timer.elapsed()
-
-            # Save checkpoint every save_interval minutes
-            if(int(timer.elapsed()) >= 60 * save_interval):
-                timer.reset()
-
-                save_diffusion_model(
-                    save_path,
-                    diffusion_prior,
-                    trainer.optimizer,
-                    trainer.scaler,
-                    config,
-                    image_embed_dim)
-
-            # Log to wandb
-            tracker.log({"Training loss": loss,
-                         "Steps": step,
-                         "Samples per second": samples_per_sec})
-            # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
-            # Use NUM_TEST_EMBEDDINGS samples from the test set each time
-            # Get embeddings from the most recently saved model
-            if(step % REPORT_METRICS_EVERY) == 0:
-                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
-                ### Evaluate model(validation run) ###
-                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
-
-            step += 1
-            trainer.update()
-
-    ### Test run ###
-    eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
+    # send config to get processed
+    initialize_training(config, accelerator)
 
 
 if __name__ == "__main__":
-    train()
+    main()