Distributed Training of the Prior (#112)

* distributed prior trainer

better EMA support

update load and save methods of prior

* update prior training script

add test evaluation & ema validation

add more tracking metrics

small cleanup
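A minimal sketch of the interface this PR lands, assembled from the diff below; the network dims, dummy batch, and hyperparameter values are illustrative, not values from this commit:

import torch
from accelerate import Accelerator
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, DiffusionPriorTrainer

# toy prior, embedding-conditioned so no CLIP download is needed
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)
diffusion_prior = DiffusionPrior(net = prior_network, image_embed_dim = 512, timesteps = 100, condition_on_text_encodings = False)

trainer = DiffusionPriorTrainer(
    diffusion_prior = diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    use_ema = True,
    accelerator = Accelerator(),   # or pass device = ... to stay single-process
)

image_embeds, text_embeds = torch.randn(4, 512), torch.randn(4, 512)
loss = trainer(text_embed = text_embeds, image_embed = image_embeds)
trainer.update()   # backprop already ran in forward; this steps the optimizer and the EMA copy across all ranks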
Author: zion
Date: 2022-06-19 10:46:14 -05:00 (committed by GitHub)
Parent: 6651eafa93
Commit: fe19b508ca
2 changed files with 506 additions and 349 deletions

dalle2_pytorch/trainer.py

@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
 from dalle2_pytorch.version import __version__
 from packaging import version
+from accelerate import Accelerator
+
 import numpy as np

 # helper functions
@@ -205,7 +207,7 @@ class EMA(nn.Module):
         self,
         model,
         beta = 0.9999,
-        update_after_step = 10000,
+        update_after_step = 100,
         update_every = 10,
         inv_gamma = 1.0,
         power = 2/3,
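Note on the update_after_step change above: no EMA updates happen before update_after_step, and once they start the decay ramps from 0 toward the beta ceiling, so dropping the threshold from 10000 to 100 lets the EMA copy start tracking the online model early instead of sitting frozen for the first 10k steps. A rough sketch of the warmup curve these hyperparameters imply, assuming the usual inverse-gamma schedule (the authoritative formula lives in the EMA class, outside this hunk):

def implied_ema_decay(step, beta = 0.9999, inv_gamma = 1.0, power = 2 / 3):
    # assumed warmup curve: rises from 0 toward the beta ceiling
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(beta, max(0.0, value))

# roughly 0.95 at step 100, 0.99 at step 1000, capped at 0.9999 thereafter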
@@ -280,6 +282,7 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs) return self.ema_model(*args, **kwargs)
# diffusion prior trainer # diffusion prior trainer
def prior_sample_in_chunks(fn): def prior_sample_in_chunks(fn):
@@ -303,88 +306,189 @@ class DiffusionPriorTrainer(nn.Module):
         max_grad_norm = None,
         amp = False,
         group_wd_params = True,
+        device = None,
+        accelerator = None,
         **kwargs
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
+        assert not exists(accelerator) or isinstance(accelerator, Accelerator)
+        assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
+
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

+        # assign some helpful member vars
+
+        self.accelerator = accelerator
+        self.device = accelerator.device if exists(accelerator) else device
+        self.text_conditioned = diffusion_prior.condition_on_text_encodings
+
+        # save model
+
         self.diffusion_prior = diffusion_prior

-        # exponential moving average
-
-        self.use_ema = use_ema
-        if self.use_ema:
-            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
-
         # optimizer and mixed precision stuff

         self.amp = amp

         self.scaler = GradScaler(enabled = amp)

+        self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
+
         self.optimizer = get_optimizer(
-            diffusion_prior.parameters(),
-            lr = lr,
-            wd = wd,
-            eps = eps,
-            group_wd_params = group_wd_params,
+            self.diffusion_prior.parameters(),
+            **self.optim_kwargs,
             **kwargs
         )

+        # distribute the model if using HFA
+
+        if exists(self.accelerator):
+            self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
+
+        # exponential moving average stuff
+
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
+
         # gradient clipping if needed

         self.max_grad_norm = max_grad_norm

+        # track steps internally
+
         self.register_buffer('step', torch.tensor([0]))

+    # accelerator wrappers
+
+    def print(self, msg):
+        if exists(self.accelerator):
+            self.accelerator.print(msg)
+        else:
+            print(msg)
+
+    def unwrap_model(self, model):
+        if exists(self.accelerator):
+            return self.accelerator.unwrap_model(model)
+        else:
+            return model
+
+    def wait_for_everyone(self):
+        if exists(self.accelerator):
+            self.accelerator.wait_for_everyone()
+
+    def is_main_process(self):
+        if exists(self.accelerator):
+            return self.accelerator.is_main_process
+        else:
+            return True
+
+    def clip_grad_norm_(self, *args):
+        if exists(self.accelerator):
+            return self.accelerator.clip_grad_norm_(*args)
+        else:
+            return torch.nn.utils.clip_grad_norm_(*args)
+
+    def backprop(self, x):
+        if exists(self.accelerator):
+            self.accelerator.backward(x)
+        else:
+            try:
+                x.backward()
+            except Exception as e:
+                self.print(f"Caught error in backprop call: {e}")
+
+    # utility
+
     def save(self, path, overwrite = True, **kwargs):
-        path = Path(path)
-        assert not (path.exists() and not overwrite)
-        path.parent.mkdir(parents = True, exist_ok = True)
+        # ensure we sync gradients before continuing
+        self.wait_for_everyone()

-        save_obj = dict(
-            scaler = self.scaler.state_dict(),
-            optimizer = self.optimizer.state_dict(),
-            model = self.diffusion_prior.state_dict(),
-            version = __version__,
-            step = self.step.item(),
-            **kwargs
-        )
+        # only save on the main process
+        if self.is_main_process():
+            self.print(f"Saving checkpoint at step: {self.step.item()}")
+            path = Path(path)
+            assert not (path.exists() and not overwrite)
+            path.parent.mkdir(parents = True, exist_ok = True)

-        if self.use_ema:
-            save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
+            save_obj = dict(
+                scaler = self.scaler.state_dict(),
+                optimizer = self.optimizer.state_dict(),
+                model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
+                version = version.parse(__version__),
+                step = self.step.item(),
+                **kwargs
+            )

-        torch.save(save_obj, str(path))
+            if self.use_ema:
+                save_obj = {
+                    **save_obj,
+                    'ema': self.ema_diffusion_prior.state_dict(),
+                    'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
+                }

-    def load(self, path, only_model = False, strict = True):
+            torch.save(save_obj, str(path))
+
+    def load(self, path, overwrite_lr = True, strict = True):
+        """
+        Load a checkpoint of a diffusion prior trainer.
+
+        Will load the entire trainer, including the optimizer and EMA.
+
+        Params:
+            - path (str): a path to the DiffusionPriorTrainer checkpoint file
+            - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
+            - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
+
+        Returns:
+            loaded_obj (dict): The loaded checkpoint dictionary
+        """
+
+        # all processes need to load checkpoint. no restriction here
         path = Path(path)
         assert path.exists()
-        loaded_obj = torch.load(str(path))
+        loaded_obj = torch.load(str(path), map_location=self.device)

         if version.parse(__version__) != loaded_obj['version']:
             print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')

-        self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
+        # unwrap the model when loading from checkpoint
+        self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
+
         self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

-        if only_model:
-            return loaded_obj
-
         self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])

+        if overwrite_lr:
+            new_lr = self.optim_kwargs["lr"]
+            self.print(f"Overriding LR to be {new_lr}")
+
+            for group in self.optimizer.param_groups:
+                group["lr"] = new_lr
+
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+            # below may not be necessary, but I had a suspicion that this wasn't being loaded correctly
+            self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
+
+        # sync and inform
+        self.wait_for_everyone()
+        self.print(f"Loaded model")

         return loaded_obj

+    # model functionality
+
     def update(self):
+        # only continue with updates until all ranks finish
+        self.wait_for_everyone()
+
         if exists(self.max_grad_norm):
             self.scaler.unscale_(self.optimizer)
-            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+            # utilize HFA clipping where applicable
+            self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

         self.scaler.step(self.optimizer)
         self.scaler.update()
@@ -399,17 +503,32 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+        else:
+            return self.diffusion_prior.p_sample_loop(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
     @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+        else:
+            return self.diffusion_prior.sample(*args, **kwargs)

     @torch.no_grad()
     def sample_batch_size(self, *args, **kwargs):
-        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+        if self.use_ema:
+            return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+        else:
+            return self.diffusion_prior.sample_batch_size(*args, **kwargs)
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @prior_sample_in_chunks
+    def embed_text(self, *args, **kwargs):
+        return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)

     @cast_torch_tensor
     def forward(
@@ -427,8 +546,10 @@ class DiffusionPriorTrainer(nn.Module):
                 total_loss += loss.item()

+        # backprop with accelerate if applicable
+
         if self.training:
-            self.scaler.scale(loss).backward()
+            self.backprop(self.scaler.scale(loss))

         return total_loss
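Taken together with the save changes above, a checkpoint round trip now looks roughly like this (path illustrative, trainer built as in the sketch near the commit message):

trainer.save('./checkpoints/prior_step_1000.pth')   # only the main process writes; all ranks sync first
loaded = trainer.load('./checkpoints/prior_step_1000.pth', overwrite_lr = True)
# model, EMA, optimizer, scaler and step counter are restored;
# overwrite_lr then reapplies the current trainer's LR on top of the stored one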

train_diffusion_prior.py

@@ -1,77 +1,136 @@
-from pathlib import Path
+# TODO: add start, num_data_points, eval_every and group to config
+# TODO: switch back to repo's wandb
+START = 0
+NUM_DATA_POINTS = 250e6
+EVAL_EVERY = 1000
+GROUP = "distributed"
+
+import os
 import click
-import math
-import numpy as np
+import wandb
 import torch
-import clip
+
 from torch import nn
-from dalle2_pytorch.dataloaders import make_splits, get_reader
-from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
-from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
-from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
-from dalle2_pytorch.utils import Timer, print_ribbon
-from tqdm import tqdm
+from torch.nn.functional import normalize
+from torch.utils.data import DataLoader
+
+import numpy as np
+
+from accelerate import Accelerator
+
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.train_configs import (
+    DiffusionPriorTrainConfig,
+    TrainDiffusionPriorConfig,
+)
+from dalle2_pytorch.trackers import BaseTracker, WandbTracker
+from dalle2_pytorch import DiffusionPriorTrainer

-# constants
+# helpers

-REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
-
-tracker = WandbTracker()
-
-# helpers functions
+cos = nn.CosineSimilarity(dim=1, eps=1e-6)

 def exists(val):
-    val is not None
+    return val is not None

-# functions
+def make_model(
+    prior_config, train_config, device: str = None, accelerator: Accelerator = None
+):
+    # create model from config
+    diffusion_prior = prior_config.create()

-def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
-    model.eval()
+    # instantiate the trainer
+    trainer = DiffusionPriorTrainer(
+        diffusion_prior=diffusion_prior,
+        lr=train_config.lr,
+        wd=train_config.wd,
+        max_grad_norm=train_config.max_grad_norm,
+        amp=train_config.amp,
+        use_ema=train_config.use_ema,
+        device=device,
+        accelerator=accelerator,
+    )
+
+    return trainer
+
+# eval functions
+
+def eval_model(
+    trainer: DiffusionPriorTrainer,
+    dataloader: DataLoader,
+    text_conditioned: bool,
+    loss_type: str,
+    phase: str,
+    tracker: BaseTracker = None,
+    use_ema: bool = True,
+):
+    trainer.eval()
+    if trainer.is_main_process():
+        click.secho(f"Measuring performance on {phase}", fg="green", blink=True)

     with torch.no_grad():
-        total_loss = 0.
-        total_samples = 0.
+        total_loss = 0.0
+        total_samples = 0.0

-        for image_embeddings, text_data in tqdm(dataloader):
-            image_embeddings = image_embeddings.to(device)
-            text_data = text_data.to(device)
+        for image_embeddings, text_data in dataloader:
+            image_embeddings = image_embeddings.to(trainer.device)
+            text_data = text_data.to(trainer.device)

             batches = image_embeddings.shape[0]

             input_args = dict(image_embed=image_embeddings)

             if text_conditioned:
-                input_args = dict(**input_args, text = text_data)
+                input_args = dict(**input_args, text=text_data)
             else:
                 input_args = dict(**input_args, text_embed=text_data)

-            loss = model(**input_args)
+            if use_ema:
+                loss = trainer.ema_diffusion_prior(**input_args)
+            else:
+                loss = trainer(**input_args)

             total_loss += loss * batches
             total_samples += batches

-        avg_loss = (total_loss / total_samples)
-        tracker.log({f'{phase} {loss_type}': avg_loss})
+        avg_loss = total_loss / total_samples
+        stats = {f"{phase}/{loss_type}": avg_loss}
+        trainer.print(stats)

-def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
-    diffusion_prior.eval()
-    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
-    for test_image_embeddings, text_data in tqdm(dataloader):
-        test_image_embeddings = test_image_embeddings.to(device)
-        text_data = text_data.to(device)
+        if exists(tracker):
+            tracker.log(stats, step=trainer.step.item() + 1)
+
+def report_cosine_sims(
+    trainer: DiffusionPriorTrainer,
+    dataloader: DataLoader,
+    text_conditioned: bool,
+    tracker: BaseTracker,
+    phase: str = "validation",
+):
+    trainer.eval()
+    if trainer.is_main_process():
+        click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
+
+    for test_image_embeddings, text_data in dataloader:
+        test_image_embeddings = test_image_embeddings.to(trainer.device)
+        text_data = text_data.to(trainer.device)

         # we are text conditioned, we produce an embedding from the tokenized text
         if text_conditioned:
-            text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
-                text_data)
-            text_cond = dict(text_embed=text_embedding,
-                             text_encodings=text_encodings, mask=text_mask)
+            text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
+            text_cond = dict(
+                text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
+            )
         else:
             text_embedding = text_data
             text_cond = dict(text_embed=text_embedding)
@@ -82,8 +141,7 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
         # roll the text to simulate "unrelated" captions
         rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
         text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / \
-            text_embed_shuffled.norm(dim=1, keepdim=True)
+        text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)

         if text_conditioned:
             text_encodings_shuffled = text_encodings[rolled_idx]
@@ -92,294 +150,272 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
             text_encodings_shuffled = None
             text_mask_shuffled = None

-        text_cond_shuffled = dict(text_embed=text_embed_shuffled,
-                                  text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
+        text_cond_shuffled = dict(
+            text_embed=text_embed_shuffled,
+            text_encodings=text_encodings_shuffled,
+            mask=text_mask_shuffled,
+        )

         # prepare the text embedding
-        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
+        text_embed = text_embedding / normalize(text_embedding)

         # prepare image embeddings
-        test_image_embeddings = test_image_embeddings / \
-            test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)

         # predict on the unshuffled text embeddings
-        predicted_image_embeddings = diffusion_prior.p_sample_loop(
-            test_image_embeddings.shape, text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / \
-            predicted_image_embeddings.norm(dim=1, keepdim=True)
+        predicted_image_embeddings = trainer.p_sample_loop(
+            test_image_embeddings.shape, text_cond
+        )
+
+        predicted_image_embeddings = predicted_image_embeddings / normalize(
+            predicted_image_embeddings
+        )

         # predict on the shuffled embeddings
-        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
-            test_image_embeddings.shape, text_cond_shuffled)
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
-            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+        predicted_unrelated_embeddings = trainer.p_sample_loop(
+            test_image_embeddings.shape, text_cond_shuffled
+        )
+
+        predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
+            predicted_unrelated_embeddings
+        )

         # calculate similarities
-        original_similarity = cos(
-            text_embed, test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(
-            text_embed, predicted_image_embeddings).cpu().numpy()
-        unrelated_similarity = cos(
-            text_embed, predicted_unrelated_embeddings).cpu().numpy()
-        predicted_img_similarity = cos(
-            test_image_embeddings, predicted_image_embeddings).cpu().numpy()
+        original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
+        unrelated_similarity = (
+            cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
+        )
+        predicted_img_similarity = (
+            cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
+        )

-        tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
-                     "CosineSimilarity(text_embed,predicted_image_embed)": np.mean(predicted_similarity),
-                     "CosineSimilarity(orig_image_embed,predicted_image_embed)": np.mean(predicted_img_similarity),
-                     "CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
-                     "Cosine similarity difference": np.mean(predicted_similarity - original_similarity)})
+        stats = {
+            f"{phase}/baseline similarity": np.mean(original_similarity),
+            f"{phase}/similarity with text": np.mean(predicted_similarity),
+            f"{phase}/similarity with original image": np.mean(
+                predicted_img_similarity
+            ),
+            f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
+            f"{phase}/difference from baseline similarity": np.mean(
+                predicted_similarity - original_similarity
+            ),
+        }
+
+        for k, v in stats.items():
+            trainer.print(f"{phase}/{k}: {v}")
+
+        if exists(tracker):
+            tracker.log(stats, step=trainer.step.item() + 1)
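One thing to flag for follow-up: torch.nn.functional.normalize(x, dim=1) already returns the unit-norm tensor, so the old x / x.norm(dim=1, keepdim=True) idiom translates to a plain normalize(x) call; the x / normalize(x) form used in the replacements above divides by the unit vector instead. A quick check of the equivalence:

import torch
from torch.nn.functional import normalize

x = torch.randn(4, 768)
manual = x / x.norm(dim = 1, keepdim = True)   # the idiom being replaced
builtin = normalize(x, dim = 1)                # drop-in equivalent
assert torch.allclose(manual, builtin, atol = 1e-6)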
+# training script
+
+def train(
+    trainer: DiffusionPriorTrainer,
+    train_loader: DataLoader,
+    eval_loader: DataLoader,
+    test_loader: DataLoader,
+    config: DiffusionPriorTrainConfig,
+):
+    # distributed tracking with wandb
+    if trainer.accelerator.num_processes > 1:
+        os.environ["WANDB_START_METHOD"] = "thread"
+
+    tracker = wandb.init(
+        name=f"RANK:{trainer.device}",
+        entity=config.tracker.wandb_entity,
+        project=config.tracker.wandb_project,
+        config=config.dict(),
+        group=GROUP,
+    )
+
+    # sync after tracker init
+    trainer.wait_for_everyone()
+
+    # init a timer
+    timer = Timer()
+
+    # do training
+    for img, txt in train_loader:
+        trainer.train()
+        current_step = trainer.step.item() + 1
+
+        # place data on device
+        img = img.to(trainer.device)
+        txt = txt.to(trainer.device)
+
+        # pass to model
+        loss = trainer(text=txt, image_embed=img)
+
+        # display & log loss (will only print from main process)
+        trainer.print(f"Step {current_step}: Loss {loss}")
+
+        # perform backprop & apply EMA updates
+        trainer.update()
+
+        # track samples/sec/rank
+        samples_per_sec = img.shape[0] / timer.elapsed()
+
+        # samples seen
+        samples_seen = (
+            config.data.batch_size * trainer.accelerator.num_processes * current_step
+        )
+
+        # ema decay
+        ema_decay = trainer.ema_diffusion_prior.get_current_decay()
+
+        # Log on all processes for debugging
+        tracker.log(
+            {
+                "training/loss": loss,
+                "samples/sec/rank": samples_per_sec,
+                "samples/seen": samples_seen,
+                "ema/decay": ema_decay,
+            },
+            step=current_step,
+        )
+
+        # Metric Tracking & Checkpointing (outside of timer's scope)
+        if current_step % EVAL_EVERY == 0:
+            eval_model(
+                trainer,
+                eval_loader,
+                config.prior.condition_on_text_encodings,
+                config.prior.loss_type,
+                "training/validation",
+                tracker,
+                use_ema=False,
+            )
+
+            eval_model(
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                loss_type=config.prior.loss_type,
+                phase="ema/validation",
+                tracker=tracker,
+                use_ema=True,
+            )
+
+            report_cosine_sims(
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                tracker=tracker,
+                phase="ema/validation",
+            )
+
+        if current_step % config.train.save_every == 0:
+            trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
+
+        # reset timer for next round
+        timer.reset()
+
+    # evaluate on test data
+    eval_model(
+        trainer=trainer,
+        dataloader=test_loader,
+        text_conditioned=config.prior.condition_on_text_encodings,
+        loss_type=config.prior.loss_type,
+        phase="test",
+        tracker=tracker,
+    )
+
+    report_cosine_sims(
+        trainer,
+        test_loader,
+        config.prior.condition_on_text_encodings,
+        tracker,
+        phase="test",
+    )
+def initialize_training(config, accelerator=None):
+    """
+    Parse the configuration file, and prepare everything necessary for training
+    """
+    # get a device
+    if accelerator:
+        device = accelerator.device
+        click.secho(f"Accelerating on: {device}", fg="yellow")
+    else:
+        if torch.cuda.is_available():
+            click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
+            device = "cuda:0"
+        else:
+            click.secho("No GPU detected...using cpu", fg="yellow")
+            device = "cpu"
+
+    # make the trainer (will automatically distribute if possible & configured)
+    trainer = make_model(config.prior, config.train, device, accelerator).to(device)
+
+    # reload from checkpoint
+    if config.load.resume == True:
+        click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
+        trainer.load(config.load.source)
+
+    # fetch and prepare data
+    if trainer.is_main_process():
+        click.secho("Grabbing data from source", fg="blue", blink=True)
+
+    img_reader = get_reader(
+        text_conditioned=trainer.text_conditioned,
+        img_url=config.data.image_url,
+        meta_url=config.data.meta_url,
+    )
+
+    train_loader, eval_loader, test_loader = make_splits(
+        text_conditioned=trainer.text_conditioned,
+        batch_size=config.data.batch_size,
+        num_data_points=NUM_DATA_POINTS,
+        train_split=config.data.splits.train,
+        eval_split=config.data.splits.val,
+        image_reader=img_reader,
+        rank=accelerator.state.process_index if exists(accelerator) else 0,
+        world_size=accelerator.state.num_processes if exists(accelerator) else 1,
+        start=START,
+    )
+
+    # wait for everyone to load data before continuing
+    trainer.wait_for_everyone()
+
+    # start training
+    train(
+        trainer=trainer,
+        train_loader=train_loader,
+        eval_loader=eval_loader,
+        test_loader=test_loader,
+        config=config,
+    )
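Since all hyperparameters now flow through the JSON config, here is a guessed-at skeleton for configs/prior.json, expressed as a Python dict and inferred only from the fields this script reads (the authoritative schema is TrainDiffusionPriorConfig in dalle2_pytorch.train_configs; numeric values echo the old CLI defaults removed below, and save_every is a placeholder):

config = {
    "prior": {"condition_on_text_encodings": True, "loss_type": "l2"},  # plus the usual network / clip settings
    "train": {"lr": 1.1e-4, "wd": 6.02e-2, "max_grad_norm": 0.5,
              "amp": False, "use_ema": True, "save_every": 10000},
    "data": {"image_url": "<embeddings url>", "meta_url": "<metadata url>",
             "batch_size": 320,
             "splits": {"train": 0.9, "val": 1e-7, "test": 0.0999999}},
    "load": {"resume": False, "source": None},
    "tracker": {"wandb_entity": "laion", "wandb_project": "diffusion-prior",
                "data_path": "./diffusion_prior_checkpoints"},
}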
 @click.command()
-@click.option("--wandb-entity", default="laion")
-@click.option("--wandb-project", default="diffusion-prior")
-@click.option("--wandb-dataset", default="LAION-5B")
-@click.option("--wandb-arch", default="DiffusionPrior")
-@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
-@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
-@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
-@click.option("--learning-rate", default=1.1e-4)
-@click.option("--weight-decay", default=6.02e-2)
-@click.option("--dropout", default=5e-2)
-@click.option("--max-grad-norm", default=0.5)
-@click.option("--num-data-points", default=250e6)
-@click.option("--batch-size", default=320)
-@click.option("--num-epochs", default=5)
-@click.option("--image-embed-dim", default=768)
-@click.option("--train-percent", default=0.9)
-@click.option("--val-percent", default=1e-7)
-@click.option("--test-percent", default=0.0999999)
-@click.option("--dpn-depth", default=12)
-@click.option("--dpn-dim-head", default=64)
-@click.option("--dpn-heads", default=12)
-@click.option("--dp-condition-on-text-encodings", default=True)
-@click.option("--dp-timesteps", default=1000)
-@click.option("--dp-normformer", default=True)
-@click.option("--dp-cond-drop-prob", default=0.1)
-@click.option("--dp-loss-type", default="l2")
-@click.option("--clip", default="ViT-L/14")
-@click.option("--amp", default=False)
-@click.option("--save-interval", default=120)
-@click.option("--save-path", default="./diffusion_prior_checkpoints")
-@click.option("--pretrained-model-path", default=None)
-@click.option("--gpu-device", default=0)
-def train(
-    wandb_entity,
-    wandb_project,
-    wandb_dataset,
-    wandb_arch,
-    image_embed_url,
-    text_embed_url,
-    meta_url,
-    learning_rate,
-    weight_decay,
-    dropout,
-    max_grad_norm,
-    num_data_points,
-    batch_size,
-    num_epochs,
-    image_embed_dim,
-    train_percent,
-    val_percent,
-    test_percent,
-    dpn_depth,
-    dpn_dim_head,
-    dpn_heads,
-    dp_condition_on_text_encodings,
-    dp_timesteps,
-    dp_normformer,
-    dp_cond_drop_prob,
-    dp_loss_type,
-    clip,
-    amp,
-    save_interval,
-    save_path,
-    pretrained_model_path,
-    gpu_device
-):
-    config = {
-        "learning_rate": learning_rate,
-        "architecture": wandb_arch,
-        "dataset": wandb_dataset,
-        "weight_decay": weight_decay,
-        "max_gradient_clipping_norm": max_grad_norm,
-        "batch_size": batch_size,
-        "epochs": num_epochs,
-        "diffusion_prior_network": {
-            "depth": dpn_depth,
-            "dim_head": dpn_dim_head,
-            "heads": dpn_heads,
-            "normformer": dp_normformer
-        },
-        "diffusion_prior": {
-            "condition_on_text_encodings": dp_condition_on_text_encodings,
-            "timesteps": dp_timesteps,
-            "cond_drop_prob": dp_cond_drop_prob,
-            "loss_type": dp_loss_type,
-            "clip": clip
-        }
-    }
-
-    # Check if DPRIOR_PATH exists(saved model path)
-    DPRIOR_PATH = pretrained_model_path
-    RESUME = exists(DPRIOR_PATH)
-
-    if not RESUME:
-        tracker.init(
-            entity = wandb_entity,
-            project = wandb_project,
-            config = config
-        )
-
-    # Obtain the utilized device.
-    has_cuda = torch.cuda.is_available()
-    if has_cuda:
-        device = torch.device(f"cuda:{gpu_device}")
-        torch.cuda.set_device(device)
-
-    # Training loop
-
-    # diffusion prior network
-    prior_network = DiffusionPriorNetwork(
-        dim = image_embed_dim,
-        depth = dpn_depth,
-        dim_head = dpn_dim_head,
-        heads = dpn_heads,
-        attn_dropout = dropout,
-        ff_dropout = dropout,
-        normformer = dp_normformer
-    )
-
-    # Load clip model if text-conditioning
-    if dp_condition_on_text_encodings:
-        clip_adapter = OpenAIClipAdapter(clip)
-    else:
-        clip_adapter = None
-
-    # diffusion prior with text embeddings and image embeddings pre-computed
-    diffusion_prior = DiffusionPrior(
-        net = prior_network,
-        clip = clip_adapter,
-        image_embed_dim = image_embed_dim,
-        timesteps = dp_timesteps,
-        cond_drop_prob = dp_cond_drop_prob,
-        loss_type = dp_loss_type,
-        condition_on_text_encodings = dp_condition_on_text_encodings
-    )
-
-    # Load pre-trained model from DPRIOR_PATH
-    if RESUME:
-        diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
-        tracker.init(entity = wandb_entity, project = wandb_project, config = config)
-
-    # diffusion prior trainer
-    trainer = DiffusionPriorTrainer(
-        diffusion_prior = diffusion_prior,
-        lr = learning_rate,
-        wd = weight_decay,
-        max_grad_norm = max_grad_norm,
-        amp = amp,
-    ).to(device)
-
-    # load optimizer and scaler
-    if RESUME:
-        trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
-        trainer.scaler.load_state_dict(loaded_obj['scaler'])
-
-    # Create save_path if it doesn't exist
-    Path(save_path).mkdir(exist_ok = True, parents = True)
-
-    # Utilize wrapper to abstract away loader logic
-    print_ribbon("Downloading Embeddings")
-    reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
-
-    if dp_condition_on_text_encodings:
-        reader_args = dict(**reader_args, meta_url=meta_url)
-        img_reader = get_reader(**reader_args)
-        train_loader, eval_loader, test_loader = make_splits(
-            text_conditioned=dp_condition_on_text_encodings,
-            batch_size=batch_size,
-            num_data_points=num_data_points,
-            train_split=train_percent,
-            eval_split=val_percent,
-            image_reader=img_reader
-        )
-    else:
-        reader_args = dict(**reader_args, txt_url=text_embed_url)
-        img_reader, txt_reader = get_reader(**reader_args)
-        train_loader, eval_loader, test_loader = make_splits(
-            text_conditioned=dp_condition_on_text_encodings,
-            batch_size=batch_size,
-            num_data_points=num_data_points,
-            train_split=train_percent,
-            eval_split=val_percent,
-            image_reader=img_reader,
-            text_reader=txt_reader
-        )
-
-    ### Training code ###
-    step = 1
-    timer = Timer()
-    epochs = num_epochs
-
-    for _ in range(epochs):
-        for image, text in tqdm(train_loader):
-            diffusion_prior.train()
-
-            image = image.to(device)
-            text = text.to(device)
-
-            input_args = dict(image_embed=image)
-
-            if dp_condition_on_text_encodings:
-                input_args = dict(**input_args, text = text)
-            else:
-                input_args = dict(**input_args, text_embed=text)
-
-            loss = trainer(**input_args)
-
-            # Samples per second
-            samples_per_sec = batch_size * step / timer.elapsed()
-
-            # Save checkpoint every save_interval minutes
-            if(int(timer.elapsed()) >= 60 * save_interval):
-                timer.reset()
-                save_diffusion_model(
-                    save_path,
-                    diffusion_prior,
-                    trainer.optimizer,
-                    trainer.scaler,
-                    config,
-                    image_embed_dim)
-
-            # Log to wandb
-            tracker.log({"Training loss": loss,
-                         "Steps": step,
-                         "Samples per second": samples_per_sec})
-
-            # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
-            # Use NUM_TEST_EMBEDDINGS samples from the test set each time
-            # Get embeddings from the most recently saved model
-            if(step % REPORT_METRICS_EVERY) == 0:
-                report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
-
-                ### Evaluate model(validation run) ###
-                eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
-
-            step += 1
-            trainer.update()
-
-    ### Test run ###
-    eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
+@click.option("--hfa", default=True)
+@click.option("--config_path", default="configs/prior.json")
+def main(hfa, config_path):
+    # start HFA if requested
+    if hfa:
+        accelerator = Accelerator()
+    else:
+        accelerator = None
+
+    # load the configuration file on main process
+    if not exists(accelerator) or accelerator.is_main_process:
+        click.secho(f"Loading configuration from {config_path}", fg="green")
+
+    config = TrainDiffusionPriorConfig.from_json_path(config_path)
+
+    # send config to get processed
+    initialize_training(config, accelerator)

 if __name__ == "__main__":
-    train()
+    main()
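Usage note: with --hfa left at its default of True the script builds an Accelerator, so it is meant to be started through Hugging Face Accelerate's launcher, e.g. "accelerate launch train_diffusion_prior.py --config_path configs/prior.json" after a one-time "accelerate config" to describe the machine topology (the script filename here is assumed from the repository layout).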