mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Distributed Training of the Prior (#112)
* distributed prior trainer better EMA support update load and save methods of prior * update prior training script add test evalution & ema validation add more tracking metrics small cleanup
This commit is contained in:
@@ -14,6 +14,8 @@ from dalle2_pytorch.optimizer import get_optimizer
|
||||
from dalle2_pytorch.version import __version__
|
||||
from packaging import version
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
import numpy as np
|
||||
|
||||
# helper functions
|
||||
@@ -205,7 +207,7 @@ class EMA(nn.Module):
|
||||
self,
|
||||
model,
|
||||
beta = 0.9999,
|
||||
update_after_step = 10000,
|
||||
update_after_step = 100,
|
||||
update_every = 10,
|
||||
inv_gamma = 1.0,
|
||||
power = 2/3,
|
||||
@@ -280,6 +282,7 @@ class EMA(nn.Module):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
def prior_sample_in_chunks(fn):
|
||||
@@ -303,42 +306,106 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert exists(accelerator) or exists(device), "You must supply some method of obtaining a device."
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# assign some helpful member vars
|
||||
self.accelerator = accelerator
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.text_conditioned = diffusion_prior.condition_on_text_encodings
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
|
||||
self.amp = amp
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
|
||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||
|
||||
self.optimizer = get_optimizer(
|
||||
diffusion_prior.parameters(),
|
||||
lr = lr,
|
||||
wd = wd,
|
||||
eps = eps,
|
||||
group_wd_params = group_wd_params,
|
||||
self.diffusion_prior.parameters(),
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# distribute the model if using HFA
|
||||
if exists(self.accelerator):
|
||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
||||
|
||||
# exponential moving average stuff
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
# track steps internally
|
||||
|
||||
self.register_buffer('step', torch.tensor([0]))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def unwrap_model(self, model):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.unwrap_model(model)
|
||||
else:
|
||||
return model
|
||||
|
||||
def wait_for_everyone(self):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def is_main_process(self):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.is_main_process
|
||||
else:
|
||||
return True
|
||||
|
||||
def clip_grad_norm_(self, *args):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.clip_grad_norm_(*args)
|
||||
else:
|
||||
return torch.nn.utils.clip_grad_norm_(*args)
|
||||
|
||||
def backprop(self, x):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.backward(x)
|
||||
else:
|
||||
try:
|
||||
x.backward()
|
||||
except Exception as e:
|
||||
self.print(f"Caught error in backprop call: {e}")
|
||||
|
||||
# utility
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
# ensure we sync gradients before continuing
|
||||
self.wait_for_everyone()
|
||||
|
||||
# only save on the main process
|
||||
if self.is_main_process():
|
||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
@@ -346,45 +413,82 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
save_obj = dict(
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.diffusion_prior.state_dict(),
|
||||
version = __version__,
|
||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
||||
version = version.parse(__version__),
|
||||
step = self.step.item(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_diffusion_prior.state_dict()}
|
||||
save_obj = {
|
||||
**save_obj,
|
||||
'ema': self.ema_diffusion_prior.state_dict(),
|
||||
'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
|
||||
}
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, only_model = False, strict = True):
|
||||
def load(self, path, overwrite_lr = True, strict = True):
|
||||
"""
|
||||
Load a checkpoint of a diffusion prior trainer.
|
||||
|
||||
Will load the entire trainer, including the optimizer and EMA.
|
||||
|
||||
Params:
|
||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||
|
||||
Returns:
|
||||
loaded_obj (dict): The loaded checkpoint dictionary
|
||||
"""
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path))
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
if only_model:
|
||||
return loaded_obj
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
|
||||
if overwrite_lr:
|
||||
new_lr = self.optim_kwargs["lr"]
|
||||
|
||||
self.print(f"Overriding LR to be {new_lr}")
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||
|
||||
# sync and inform
|
||||
self.wait_for_everyone()
|
||||
self.print(f"Loaded model")
|
||||
|
||||
return loaded_obj
|
||||
|
||||
# model functionality
|
||||
|
||||
def update(self):
|
||||
# only continue with updates until all ranks finish
|
||||
self.wait_for_everyone()
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
# utilize HFA clipping where applicable
|
||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
@@ -399,17 +503,32 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def p_sample_loop(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.p_sample_loop(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def sample(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_batch_size(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
|
||||
else:
|
||||
return self.diffusion_prior.sample_batch_size(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
@@ -427,8 +546,10 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# backprop with accelerate if applicable
|
||||
|
||||
if self.training:
|
||||
self.scaler.scale(loss).backward()
|
||||
self.backprop(self.scaler.scale(loss))
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
@@ -1,77 +1,136 @@
|
||||
from pathlib import Path
|
||||
# TODO: add start, num_data_points, eval_every and group to config
|
||||
# TODO: switch back to repo's wandb
|
||||
|
||||
START = 0
|
||||
NUM_DATA_POINTS = 250e6
|
||||
EVAL_EVERY = 1000
|
||||
GROUP = "distributed"
|
||||
|
||||
import os
|
||||
import click
|
||||
import math
|
||||
import numpy as np
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
import clip
|
||||
from torch import nn
|
||||
from torch.nn.functional import normalize
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dalle2_pytorch.dataloaders import make_splits, get_reader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
|
||||
import numpy as np
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||
from accelerate import Accelerator
|
||||
|
||||
from tqdm import tqdm
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.train_configs import (
|
||||
DiffusionPriorTrainConfig,
|
||||
TrainDiffusionPriorConfig,
|
||||
)
|
||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
|
||||
# constants
|
||||
|
||||
REPORT_METRICS_EVERY = 250 # for cosine similarity and other metric reporting during training
|
||||
# helpers
|
||||
|
||||
tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
return val is not None
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
|
||||
model.eval()
|
||||
def make_model(
|
||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
||||
):
|
||||
# create model from config
|
||||
diffusion_prior = prior_config.create()
|
||||
|
||||
# instantiate the trainer
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior=diffusion_prior,
|
||||
lr=train_config.lr,
|
||||
wd=train_config.wd,
|
||||
max_grad_norm=train_config.max_grad_norm,
|
||||
amp=train_config.amp,
|
||||
use_ema=train_config.use_ema,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
# eval functions
|
||||
|
||||
|
||||
def eval_model(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
loss_type: str,
|
||||
phase: str,
|
||||
tracker: BaseTracker = None,
|
||||
use_ema: bool = True,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
total_loss = 0.0
|
||||
total_samples = 0.0
|
||||
|
||||
for image_embeddings, text_data in tqdm(dataloader):
|
||||
image_embeddings = image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
|
||||
if text_conditioned:
|
||||
input_args = dict(**input_args, text = text_data)
|
||||
input_args = dict(**input_args, text=text_data)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text_data)
|
||||
|
||||
loss = model(**input_args)
|
||||
if use_ema:
|
||||
loss = trainer.ema_diffusion_prior(**input_args)
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
|
||||
avg_loss = (total_loss / total_samples)
|
||||
avg_loss = total_loss / total_samples
|
||||
|
||||
tracker.log({f'{phase} {loss_type}': avg_loss})
|
||||
stats = {f"{phase}/{loss_type}": avg_loss}
|
||||
trainer.print(stats)
|
||||
|
||||
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||
diffusion_prior.eval()
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
for test_image_embeddings, text_data in tqdm(dataloader):
|
||||
test_image_embeddings = test_image_embeddings.to(device)
|
||||
text_data = text_data.to(device)
|
||||
def report_cosine_sims(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
tracker: BaseTracker,
|
||||
phase: str = "validation",
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
||||
|
||||
for test_image_embeddings, text_data in dataloader:
|
||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings, text_mask = diffusion_prior.clip.embed_text(
|
||||
text_data)
|
||||
text_cond = dict(text_embed=text_embedding,
|
||||
text_encodings=text_encodings, mask=text_mask)
|
||||
text_embedding, text_encodings, text_mask = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings, mask=text_mask
|
||||
)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
@@ -82,8 +141,7 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||
# roll the text to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
|
||||
|
||||
if text_conditioned:
|
||||
text_encodings_shuffled = text_encodings[rolled_idx]
|
||||
@@ -92,294 +150,272 @@ def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
|
||||
text_encodings_shuffled = None
|
||||
text_mask_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled, mask=text_mask_shuffled)
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled,
|
||||
mask=text_mask_shuffled,
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
|
||||
text_embed = normalize(text_embedding / text_embedding)
|
||||
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_image_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond
|
||||
)
|
||||
|
||||
predicted_image_embeddings = predicted_image_embeddings / normalize(
|
||||
predicted_image_embeddings
|
||||
)
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled
|
||||
)
|
||||
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
|
||||
predicted_unrelated_embeddings
|
||||
)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
tracker.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = (
|
||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
)
|
||||
predicted_img_similarity = (
|
||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
)
|
||||
|
||||
stats = {
|
||||
f"{phase}/baseline similarity": np.mean(original_similarity),
|
||||
f"{phase}/similarity with text": np.mean(predicted_similarity),
|
||||
f"{phase}/similarity with original image": np.mean(
|
||||
predicted_img_similarity
|
||||
),
|
||||
f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
||||
f"{phase}/difference from baseline similarity": np.mean(
|
||||
predicted_similarity - original_similarity
|
||||
),
|
||||
}
|
||||
|
||||
for k, v in stats.items():
|
||||
trainer.print(f"{phase}/{k}: {v}")
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
|
||||
# training script
|
||||
|
||||
|
||||
def train(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
train_loader: DataLoader,
|
||||
eval_loader: DataLoader,
|
||||
test_loader: DataLoader,
|
||||
config: DiffusionPriorTrainConfig,
|
||||
):
|
||||
# distributed tracking with wandb
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
|
||||
tracker = wandb.init(
|
||||
name=f"RANK:{trainer.device}",
|
||||
entity=config.tracker.wandb_entity,
|
||||
project=config.tracker.wandb_project,
|
||||
config=config.dict(),
|
||||
group=GROUP,
|
||||
)
|
||||
|
||||
# sync after tracker init
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# init a timer
|
||||
timer = Timer()
|
||||
|
||||
# do training
|
||||
for img, txt in train_loader:
|
||||
trainer.train()
|
||||
current_step = trainer.step.item() + 1
|
||||
|
||||
# place data on device
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
|
||||
# pass to model
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
|
||||
# display & log loss (will only print from main process)
|
||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
trainer.update()
|
||||
|
||||
# track samples/sec/rank
|
||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
||||
|
||||
# samples seen
|
||||
samples_seen = (
|
||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
||||
)
|
||||
|
||||
# ema decay
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
|
||||
# Log on all processes for debugging
|
||||
tracker.log(
|
||||
{
|
||||
"training/loss": loss,
|
||||
"samples/sec/rank": samples_per_sec,
|
||||
"samples/seen": samples_seen,
|
||||
"ema/decay": ema_decay,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
|
||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
||||
if current_step % EVAL_EVERY == 0:
|
||||
eval_model(
|
||||
trainer,
|
||||
eval_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
config.prior.loss_type,
|
||||
"training/validation",
|
||||
tracker,
|
||||
use_ema=False,
|
||||
)
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss=config.prior.loss_type,
|
||||
phase="ema/validation",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
tracker=tracker,
|
||||
phase="ema/validation",
|
||||
)
|
||||
|
||||
if current_step % config.train.save_every == 0:
|
||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
||||
|
||||
# reset timer for next round
|
||||
timer.reset()
|
||||
|
||||
# evaluate on test data
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=test_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
phase="test",
|
||||
tracker=tracker,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
test_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
tracker,
|
||||
phase="test",
|
||||
)
|
||||
|
||||
|
||||
def initialize_training(config, accelerator=None):
|
||||
"""
|
||||
Parse the configuration file, and prepare everything necessary for training
|
||||
"""
|
||||
|
||||
# get a device
|
||||
|
||||
if accelerator:
|
||||
device = accelerator.device
|
||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
||||
device = "cuda:0"
|
||||
else:
|
||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
||||
device = "cpu"
|
||||
|
||||
# make the trainer (will automatically distribute if possible & configured)
|
||||
|
||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
||||
|
||||
# reload from chcekpoint
|
||||
|
||||
if config.load.resume == True:
|
||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
||||
trainer.load(config.load.source)
|
||||
|
||||
# fetch and prepare data
|
||||
|
||||
if trainer.is_main_process():
|
||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
||||
|
||||
img_reader = get_reader(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
img_url=config.data.image_url,
|
||||
meta_url=config.data.meta_url,
|
||||
)
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
batch_size=config.data.batch_size,
|
||||
num_data_points=NUM_DATA_POINTS,
|
||||
train_split=config.data.splits.train,
|
||||
eval_split=config.data.splits.val,
|
||||
image_reader=img_reader,
|
||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
||||
start=START,
|
||||
)
|
||||
|
||||
# wait for everyone to load data before continuing
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# start training
|
||||
train(
|
||||
trainer=trainer,
|
||||
train_loader=train_loader,
|
||||
eval_loader=eval_loader,
|
||||
test_loader=test_loader,
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--wandb-entity", default="laion")
|
||||
@click.option("--wandb-project", default="diffusion-prior")
|
||||
@click.option("--wandb-dataset", default="LAION-5B")
|
||||
@click.option("--wandb-arch", default="DiffusionPrior")
|
||||
@click.option("--image-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
@click.option("--text-embed-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
@click.option("--meta-url", default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/")
|
||||
@click.option("--learning-rate", default=1.1e-4)
|
||||
@click.option("--weight-decay", default=6.02e-2)
|
||||
@click.option("--dropout", default=5e-2)
|
||||
@click.option("--max-grad-norm", default=0.5)
|
||||
@click.option("--num-data-points", default=250e6)
|
||||
@click.option("--batch-size", default=320)
|
||||
@click.option("--num-epochs", default=5)
|
||||
@click.option("--image-embed-dim", default=768)
|
||||
@click.option("--train-percent", default=0.9)
|
||||
@click.option("--val-percent", default=1e-7)
|
||||
@click.option("--test-percent", default=0.0999999)
|
||||
@click.option("--dpn-depth", default=12)
|
||||
@click.option("--dpn-dim-head", default=64)
|
||||
@click.option("--dpn-heads", default=12)
|
||||
@click.option("--dp-condition-on-text-encodings", default=True)
|
||||
@click.option("--dp-timesteps", default=1000)
|
||||
@click.option("--dp-normformer", default=True)
|
||||
@click.option("--dp-cond-drop-prob", default=0.1)
|
||||
@click.option("--dp-loss-type", default="l2")
|
||||
@click.option("--clip", default="ViT-L/14")
|
||||
@click.option("--amp", default=False)
|
||||
@click.option("--save-interval", default=120)
|
||||
@click.option("--save-path", default="./diffusion_prior_checkpoints")
|
||||
@click.option("--pretrained-model-path", default=None)
|
||||
@click.option("--gpu-device", default=0)
|
||||
def train(
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
wandb_dataset,
|
||||
wandb_arch,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
meta_url,
|
||||
learning_rate,
|
||||
weight_decay,
|
||||
dropout,
|
||||
max_grad_norm,
|
||||
num_data_points,
|
||||
batch_size,
|
||||
num_epochs,
|
||||
image_embed_dim,
|
||||
train_percent,
|
||||
val_percent,
|
||||
test_percent,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
dpn_heads,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dp_loss_type,
|
||||
clip,
|
||||
amp,
|
||||
save_interval,
|
||||
save_path,
|
||||
pretrained_model_path,
|
||||
gpu_device
|
||||
):
|
||||
config = {
|
||||
"learning_rate": learning_rate,
|
||||
"architecture": wandb_arch,
|
||||
"dataset": wandb_dataset,
|
||||
"weight_decay": weight_decay,
|
||||
"max_gradient_clipping_norm": max_grad_norm,
|
||||
"batch_size": batch_size,
|
||||
"epochs": num_epochs,
|
||||
"diffusion_prior_network": {
|
||||
"depth": dpn_depth,
|
||||
"dim_head": dpn_dim_head,
|
||||
"heads": dpn_heads,
|
||||
"normformer": dp_normformer
|
||||
},
|
||||
"diffusion_prior": {
|
||||
"condition_on_text_encodings": dp_condition_on_text_encodings,
|
||||
"timesteps": dp_timesteps,
|
||||
"cond_drop_prob": dp_cond_drop_prob,
|
||||
"loss_type": dp_loss_type,
|
||||
"clip": clip
|
||||
}
|
||||
}
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
config = config
|
||||
)
|
||||
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
if has_cuda:
|
||||
device = torch.device(f"cuda:{gpu_device}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Training loop
|
||||
# diffusion prior network
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer
|
||||
)
|
||||
|
||||
# Load clip model if text-conditioning
|
||||
if dp_condition_on_text_encodings:
|
||||
clip_adapter = OpenAIClipAdapter(clip)
|
||||
@click.option("--hfa", default=True)
|
||||
@click.option("--config_path", default="configs/prior.json")
|
||||
def main(hfa, config_path):
|
||||
# start HFA if requested
|
||||
if hfa:
|
||||
accelerator = Accelerator()
|
||||
else:
|
||||
clip_adapter = None
|
||||
accelerator = None
|
||||
|
||||
# diffusion prior with text embeddings and image embeddings pre-computed
|
||||
# load the configuration file on main process
|
||||
if not exists(accelerator) or accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip_adapter,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings
|
||||
)
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
||||
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
|
||||
trainer = DiffusionPriorTrainer(
|
||||
diffusion_prior = diffusion_prior,
|
||||
lr = learning_rate,
|
||||
wd = weight_decay,
|
||||
max_grad_norm = max_grad_norm,
|
||||
amp = amp,
|
||||
).to(device)
|
||||
|
||||
# load optimizer and scaler
|
||||
|
||||
if RESUME:
|
||||
trainer.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
trainer.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
|
||||
Path(save_path).mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Utilize wrapper to abstract away loader logic
|
||||
print_ribbon("Downloading Embeddings")
|
||||
reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
|
||||
|
||||
if dp_condition_on_text_encodings:
|
||||
reader_args = dict(**reader_args, meta_url=meta_url)
|
||||
img_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader
|
||||
)
|
||||
else:
|
||||
reader_args = dict(**reader_args, txt_url=text_embed_url)
|
||||
img_reader, txt_reader = get_reader(**reader_args)
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=dp_condition_on_text_encodings,
|
||||
batch_size=batch_size,
|
||||
num_data_points=num_data_points,
|
||||
train_split=train_percent,
|
||||
eval_split=val_percent,
|
||||
image_reader=img_reader,
|
||||
text_reader=txt_reader
|
||||
)
|
||||
|
||||
### Training code ###
|
||||
|
||||
step = 1
|
||||
timer = Timer()
|
||||
epochs = num_epochs
|
||||
|
||||
for _ in range(epochs):
|
||||
|
||||
for image, text in tqdm(train_loader):
|
||||
diffusion_prior.train()
|
||||
|
||||
image = image.to(device)
|
||||
text = text.to(device)
|
||||
|
||||
input_args = dict(image_embed=image)
|
||||
if dp_condition_on_text_encodings:
|
||||
input_args = dict(**input_args, text = text)
|
||||
else:
|
||||
input_args = dict(**input_args, text_embed=text)
|
||||
|
||||
loss = trainer(**input_args)
|
||||
|
||||
# Samples per second
|
||||
|
||||
samples_per_sec = batch_size * step / timer.elapsed()
|
||||
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(timer.elapsed()) >= 60 * save_interval):
|
||||
timer.reset()
|
||||
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
diffusion_prior,
|
||||
trainer.optimizer,
|
||||
trainer.scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
tracker.log({"Training loss": loss,
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
|
||||
|
||||
step += 1
|
||||
trainer.update()
|
||||
|
||||
### Test run ###
|
||||
eval_model(diffusion_prior, test_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Test")
|
||||
# send config to get processed
|
||||
initialize_training(config, accelerator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user