Prior updates (#211)

* update configs for prior

add prior warmup to config

update example prior config

* update prior trainer & script

add deepspeed amp & warmup

adopt full accelerator support

reload at sample point

finish epoch resume code

* update tracker save method for prior

* helper functions for prior_loader
Authored by zion on 2022-07-20 18:04:26 -07:00, committed by GitHub
parent 06c65b60d2
commit f9423d308b
6 changed files with 676 additions and 352 deletions
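Editor's note: the trainer changes below drop the old GradScaler/amp plumbing in favor of Hugging Face Accelerate plus a constant LambdaLR wrapped in a linear warmup. As a rough, self-contained sketch of that scheduling pattern only (assuming the pytorch_warmup package that the trainer imports as `warmup`; the toy model and numbers here are made up, not taken from this commit):

import torch
import pytorch_warmup as warmup
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)                                 # stand-in for the diffusion prior
optimizer = torch.optim.AdamW(model.parameters(), lr=1.1e-4)

scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)      # keeps the base LR constant
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=50)

for step in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # dampening() scales the LR by the warmup factor while the wrapped scheduler advances,
    # mirroring how DiffusionPriorTrainer.update() steps its scheduler after the optimizer
    with warmup_scheduler.dampening():
        scheduler.step()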

View File

@@ -1,18 +1,14 @@
 {
     "prior": {
         "clip": {
-            "make": "x-clip",
-            "model": "ViT-L/14",
-            "base_model_kwargs": {
-                "dim_text": 768,
-                "dim_image": 768,
-                "dim_latent": 768
-            }
+            "make": "openai",
+            "model": "ViT-L/14"
         },
         "net": {
             "dim": 768,
             "depth": 12,
             "num_timesteps": 1000,
+            "max_text_len": 77,
             "num_time_embeds": 1,
             "num_image_embeds": 1,
             "num_text_embeds": 1,
@@ -20,8 +16,8 @@
             "heads": 12,
             "ff_mult": 4,
             "norm_out": true,
-            "attn_dropout": 0.0,
-            "ff_dropout": 0.0,
+            "attn_dropout": 0.05,
+            "ff_dropout": 0.05,
             "final_proj": true,
             "normformer": true,
             "rotary_emb": true
@@ -30,6 +26,7 @@
         "image_size": 224,
         "image_channels": 3,
         "timesteps": 1000,
+        "sample_timesteps": 64,
         "cond_drop_prob": 0.1,
         "loss_type": "l2",
         "predict_x_start": true,
@@ -37,34 +34,48 @@
         "condition_on_text_encodings": true
     },
     "data": {
-        "image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
-        "text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
-        "meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
-        "batch_size": 256,
+        "batch_size": 128,
+        "num_data_points": 100000,
+        "eval_every_seconds": 1600,
+        "image_url": "<path to your images>",
+        "meta_url": "<path to your metadata>",
         "splits": {
-            "train": 0.9,
-            "val": 1e-7,
-            "test": 0.0999999
+            "train": 0.8,
+            "val": 0.1,
+            "test": 0.1
        }
    },
    "train": {
-        "epochs": 1,
+        "epochs": 5,
        "lr": 1.1e-4,
        "wd": 6.02e-2,
        "max_grad_norm": 0.5,
        "use_ema": true,
+        "ema_beta": 0.9999,
+        "ema_update_after_step": 50,
+        "warmup_steps": 50,
        "amp": false,
-        "save_every": 10000
-    },
-    "load": {
-        "source": null,
-        "resume": false
+        "save_every_seconds": 3600,
+        "eval_timesteps": [64, 1000],
+        "random_seed": 84513
    },
    "tracker": {
-        "tracker_type": "wandb",
-        "data_path": "./prior_checkpoints",
-        "wandb_entity": "laion",
-        "wandb_project": "diffusion-prior",
-        "verbose": true
+        "data_path": ".prior",
+        "overwrite_data_path": true,
+        "log": {
+            "log_type": "wandb",
+            "wandb_entity": "<your wandb username>",
+            "wandb_project": "prior_debugging",
+            "wandb_resume": false,
+            "verbose": true
+        },
+        "save": [
+            {
+                "save_to": "local",
+                "save_type": "checkpoint",
+                "save_latest_to": ".prior/latest_checkpoint.pth",
+                "save_best_to": ".prior/best_checkpoint.pth"
+            }
+        ]
    }
 }
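For reference, a minimal sketch of how a config like the one above is consumed, following the make_model()/initialize_training() changes later in this commit; the path is the example default added by this PR and the keyword set is abridged, so treat it as an illustration rather than the exact call:

from accelerate import Accelerator
from dalle2_pytorch import DiffusionPriorTrainer
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig

config = TrainDiffusionPriorConfig.from_json_path("configs/train_prior_config.example.json")
accelerator = Accelerator()

diffusion_prior = config.prior.create()      # builds the DiffusionPrior from the "prior" block
trainer = DiffusionPriorTrainer(
    diffusion_prior=diffusion_prior,
    accelerator=accelerator,
    lr=config.train.lr,
    wd=config.train.wd,
    max_grad_norm=config.train.max_grad_norm,
    use_ema=config.train.use_ema,
    warmup_steps=config.train.warmup_steps,
)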

View File

@@ -67,6 +67,15 @@ class PriorEmbeddingDataset(IterableDataset):
     def __str__(self):
         return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
 
+    def set_start(self, start):
+        """
+        Adjust the starting point within the reader, useful for resuming an epoch
+        """
+        self.start = start
+
+    def get_start(self):
+        return self.start
+
     def get_sample(self):
         """
         pre-proocess data from either reader into a common format
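A short sketch of how these helpers support resuming part-way through an epoch, loosely following the start_point arithmetic added to initialize_training() in this commit; StubDataset stands in for PriorEmbeddingDataset and the numbers are made up:

class StubDataset:
    def __init__(self, start=0):
        self.start = start
    def set_start(self, start):
        self.start = start
    def get_start(self):
        return self.start

dataset = StubDataset()

samples_seen = 250_000        # recalled from the checkpoint
num_data_points = 1_000_000   # samples in one epoch
current_epoch = 1

scaled_samples = num_data_points * current_epoch
start_point = scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen

dataset.set_start(start_point)   # skip ahead within the interrupted epoch
assert dataset.get_start() == 750_000

dataset.set_start(0)             # once the resumed epoch finishes, reset to a full pass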

View File

@@ -528,12 +528,8 @@ class Tracker:
         elif save_type == 'model':
             if isinstance(trainer, DiffusionPriorTrainer):
                 prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
-                prior: DiffusionPrior = trainer.unwrap_model(prior)
-                # Remove CLIP if it is part of the model
-                original_clip = prior.clip
-                prior.clip = None
-                model_state_dict = prior.state_dict()
-                prior.clip = original_clip
+                state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
+                torch.save(state_dict, file_path)
             elif isinstance(trainer, DecoderTrainer):
                 decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
                 # Remove CLIP if it is part of the model

View File

@@ -145,6 +145,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
     normformer: bool = False
     rotary_emb: bool = True
 
+    class Config:
+        extra = "allow"
+
     def create(self):
         kwargs = self.dict()
         return DiffusionPriorNetwork(**kwargs)
@@ -187,23 +190,26 @@ class DiffusionPriorTrainConfig(BaseModel):
     use_ema: bool = True
     ema_beta: float = 0.99
     amp: bool = False
-    save_every: int = 10000  # what steps to save on
+    warmup_steps: int = None  # number of warmup steps
+    save_every_seconds: int = 3600  # how often to save
+    eval_timesteps: List[int] = [64]  # which sampling timesteps to evaluate with
+    best_validation_loss: float = 1e9  # the current best valudation loss observed
+    current_epoch: int = 0  # the current epoch
+    num_samples_seen: int = 0  # the current number of samples seen
+    random_seed: int = 0  # manual seed for torch
 
 class DiffusionPriorDataConfig(BaseModel):
     image_url: str  # path to embeddings folder
     meta_url: str  # path to metadata (captions) for images
-    splits: TrainSplitConfig
-    batch_size: int = 64
-
-class DiffusionPriorLoadConfig(BaseModel):
-    source: str = None
-    resume: bool = False
+    splits: TrainSplitConfig  # define train, validation, test splits for your dataset
+    batch_size: int  # per-gpu batch size used to train the model
+    num_data_points: int = 25e7  # total number of datapoints to train on
+    eval_every_seconds: int = 3600  # validation statistics will be performed this often
 
 class TrainDiffusionPriorConfig(BaseModel):
     prior: DiffusionPriorConfig
     data: DiffusionPriorDataConfig
     train: DiffusionPriorTrainConfig
-    load: DiffusionPriorLoadConfig
     tracker: TrackerConfig
 
     @classmethod
@@ -323,12 +329,6 @@ class DecoderEvaluateConfig(BaseModel):
     KID: Dict[str, Any] = None
     LPIPS: Dict[str, Any] = None
 
-class DecoderLoadConfig(BaseModel):
-    source: str = None  # Supports file and wandb
-    run_path: str = ''  # Used only if source is wandb
-    file_path: str = ''  # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
-    resume: bool = False  # If using wandb, whether to resume the run
-
 class TrainDecoderConfig(BaseModel):
     decoder: DecoderConfig
     data: DecoderDataConfig

View File

@@ -174,27 +174,21 @@ class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
         diffusion_prior,
+        accelerator,
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
         eps = 1e-6,
         max_grad_norm = None,
-        amp = False,
         group_wd_params = True,
-        device = None,
-        accelerator = None,
-        verbose = True,
+        warmup_steps = 1,
         **kwargs
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
-        assert not exists(accelerator) or isinstance(accelerator, Accelerator)
+        assert isinstance(accelerator, Accelerator)
 
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
-        # verbosity
-        self.verbose = verbose
-
         # assign some helpful member vars
 
         self.accelerator = accelerator
@@ -202,23 +196,31 @@ class DiffusionPriorTrainer(nn.Module):
         # setting the device
 
-        if not exists(accelerator) and not exists(device):
-            diffusion_prior_device = next(diffusion_prior.parameters()).device
-            self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
-            self.device = diffusion_prior_device
-        else:
-            self.device = accelerator.device if exists(accelerator) else device
-            diffusion_prior.to(self.device)
+        self.device = accelerator.device
+        diffusion_prior.to(self.device)
 
         # save model
 
         self.diffusion_prior = diffusion_prior
 
-        # optimizer and mixed precision stuff
-
-        self.amp = amp
-
-        self.scaler = GradScaler(enabled = amp)
+        # mixed precision checks
+
+        if (
+            exists(self.accelerator)
+            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+            and self.diffusion_prior.clip is not None
+        ):
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            self.diffusion_prior.clip.to(precision_type)
+
+        # optimizer stuff
 
         self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
@@ -227,17 +229,21 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )
 
+        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+
+        self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
+
         # distribute the model if using HFA
 
-        if exists(self.accelerator):
-            self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
+        self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
 
         # exponential moving average stuff
 
         self.use_ema = use_ema
         if self.use_ema:
-            self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
+            self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
 
         # gradient clipping if needed
@@ -247,67 +253,24 @@ class DiffusionPriorTrainer(nn.Module):
         self.register_buffer('step', torch.tensor([0], device = self.device))
 
-    # accelerator wrappers
-
-    def print(self, msg):
-        if not self.verbose:
-            return
-        if exists(self.accelerator):
-            self.accelerator.print(msg)
-        else:
-            print(msg)
-
-    def unwrap_model(self, model):
-        if exists(self.accelerator):
-            return self.accelerator.unwrap_model(model)
-        else:
-            return model
-
-    def wait_for_everyone(self):
-        if exists(self.accelerator):
-            self.accelerator.wait_for_everyone()
-
-    def is_main_process(self):
-        if exists(self.accelerator):
-            return self.accelerator.is_main_process
-        else:
-            return True
-
-    def clip_grad_norm_(self, *args):
-        if exists(self.accelerator):
-            return self.accelerator.clip_grad_norm_(*args)
-        else:
-            return torch.nn.utils.clip_grad_norm_(*args)
-
-    def backprop(self, x):
-        if exists(self.accelerator):
-            self.accelerator.backward(x)
-        else:
-            try:
-                x.backward()
-            except Exception as e:
-                self.print(f"Caught error in backprop call: {e}")
-
     # utility
 
     def save(self, path, overwrite = True, **kwargs):
-        # ensure we sync gradients before continuing
-        self.wait_for_everyone()
-
         # only save on the main process
-        if self.is_main_process():
-            self.print(f"Saving checkpoint at step: {self.step.item()}")
+        if self.accelerator.is_main_process:
+            print(f"Saving checkpoint at step: {self.step.item()}")
 
             path = Path(path)
             assert not (path.exists() and not overwrite)
             path.parent.mkdir(parents = True, exist_ok = True)
 
+            # FIXME: LambdaLR can't be saved due to pickling issues
             save_obj = dict(
-                scaler = self.scaler.state_dict(),
                 optimizer = self.optimizer.state_dict(),
-                model = self.unwrap_model(self.diffusion_prior).state_dict(),  # unwrap the model from distribution if applicable
+                warmup_scheduler = self.warmup_scheduler,
+                model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
                 version = version.parse(__version__),
-                step = self.step.item(),
+                step = self.step,
                 **kwargs
            )
@@ -320,14 +283,14 @@ class DiffusionPriorTrainer(nn.Module):
             torch.save(save_obj, str(path))
 
-    def load(self, path, overwrite_lr = True, strict = True):
+    def load(self, path_or_state, overwrite_lr = True, strict = True):
         """
         Load a checkpoint of a diffusion prior trainer.
 
         Will load the entire trainer, including the optimizer and EMA.
 
         Params:
-            - path (str): a path to the DiffusionPriorTrainer checkpoint file
+            - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
             - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
             - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
@@ -336,56 +299,56 @@ class DiffusionPriorTrainer(nn.Module):
         """
 
         # all processes need to load checkpoint. no restriction here
-        path = Path(path)
-        assert path.exists()
-
-        loaded_obj = torch.load(str(path), map_location=self.device)
+        if isinstance(path_or_state, str):
+            path = Path(path)
+            assert path.exists()
+            loaded_obj = torch.load(str(path), map_location=self.device)
+        elif isinstance(path_or_state, dict):
+            loaded_obj = path_or_state
 
         if version.parse(__version__) != loaded_obj['version']:
             print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
 
         # unwrap the model when loading from checkpoint
-        self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
-        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+        self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
 
-        self.scaler.load_state_dict(loaded_obj['scaler'])
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
 
+        # set warmupstep
+        if exists(self.warmup_scheduler):
+            self.warmup_scheduler.last_step = self.step.item()
+
+        # ensure new lr is used if different from old one
         if overwrite_lr:
             new_lr = self.optim_kwargs["lr"]
-            self.print(f"Overriding LR to be {new_lr}")
 
             for group in self.optimizer.param_groups:
-                group["lr"] = new_lr
+                group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
 
         if self.use_ema:
             assert 'ema' in loaded_obj
             self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
-            # below not be necessary, but I had a suspicion that this wasn't being loaded correctly
+            # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
             self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
 
-        # sync and inform
-        self.wait_for_everyone()
-        self.print(f"Loaded model")
-
         return loaded_obj
 
     # model functionality
 
     def update(self):
-        # only continue with updates until all ranks finish
-        self.wait_for_everyone()
-
         if exists(self.max_grad_norm):
-            self.scaler.unscale_(self.optimizer)
-
-            # utilize HFA clipping where applicable
-            self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+            self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
 
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
+        self.optimizer.step()
         self.optimizer.zero_grad()
 
+        # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
+        if not self.accelerator.optimizer_step_was_skipped:
+            with self.warmup_scheduler.dampening():
+                self.scheduler.step()
+
         if self.use_ema:
             self.ema_diffusion_prior.update()
@@ -414,7 +377,7 @@ class DiffusionPriorTrainer(nn.Module):
     @cast_torch_tensor
     @prior_sample_in_chunks
     def embed_text(self, *args, **kwargs):
-        return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
+        return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
 
     @cast_torch_tensor
     def forward(
@@ -426,16 +389,14 @@ class DiffusionPriorTrainer(nn.Module):
         total_loss = 0.
 
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
-            with autocast(enabled = self.amp):
+            with self.accelerator.autocast():
                 loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
                 loss = loss * chunk_size_frac
 
             total_loss += loss.item()
 
-            # backprop with accelerate if applicable
             if self.training:
-                self.backprop(self.scaler.scale(loss))
+                self.accelerator.backward(loss)
 
         return total_loss
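Putting the reworked trainer together, a hedged end-to-end sketch of a single training step: tiny dimensions, pre-computed embeddings instead of a CLIP model, and made-up hyperparameters, with the constructor loosely mirroring make_model() in the training script below.

import torch
from accelerate import Accelerator
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, DiffusionPriorTrainer

accelerator = Accelerator()

prior_network = DiffusionPriorNetwork(dim=64, depth=2, dim_head=32, heads=4)
diffusion_prior = DiffusionPrior(
    net=prior_network,
    image_embed_dim=64,
    timesteps=100,
    cond_drop_prob=0.1,
    condition_on_text_encodings=False,   # feed pre-computed text embeddings directly
)

trainer = DiffusionPriorTrainer(
    diffusion_prior=diffusion_prior,
    accelerator=accelerator,
    lr=1.1e-4,
    warmup_steps=10,
)

text_embed = torch.randn(4, 64).to(trainer.device)
image_embed = torch.randn(4, 64).to(trainer.device)

loss = trainer(text_embed=text_embed, image_embed=image_embed)  # forward + accelerator.backward
trainer.update()                                                # clip, optimizer step, warmup/LR, EMA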

View File

@@ -1,31 +1,23 @@
-# TODO: add start, num_data_points, eval_every and group to config
-# TODO: switch back to repo's wandb
-START = 0
-NUM_DATA_POINTS = 250e6
-EVAL_EVERY = 1000
-GROUP = "distributed"
-
-import os
 import click
-import wandb
 import torch
+
 from torch import nn
-from torch.utils.data import DataLoader
-import numpy as np
+from typing import List
 from accelerate import Accelerator
+from accelerate.utils import set_seed
+from torch.utils.data import DataLoader
 from embedding_reader import EmbeddingReader
+from accelerate.utils import dataclasses as accelerate_dataclasses
 
-from dalle2_pytorch.dataloaders import get_reader, make_splits
 from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.trackers import Tracker
+from dalle2_pytorch import DiffusionPriorTrainer
+from dalle2_pytorch.dataloaders import get_reader, make_splits
 from dalle2_pytorch.train_configs import (
+    DiffusionPriorConfig,
     DiffusionPriorTrainConfig,
     TrainDiffusionPriorConfig,
 )
-from dalle2_pytorch.trackers import BaseTracker, WandbTracker
-from dalle2_pytorch import DiffusionPriorTrainer
 
 # helpers
@@ -38,8 +30,19 @@ def exists(val):
     return val is not None
 
+
+def all_between(values: list, lower_bound, upper_bound):
+    for value in values:
+        if value < lower_bound or value > upper_bound:
+            return False
+
+    return True
+
+
 def make_model(
-    prior_config, train_config, device: str = None, accelerator: Accelerator = None
+    prior_config: DiffusionPriorConfig,
+    train_config: DiffusionPriorTrainConfig,
+    device: str = None,
+    accelerator: Accelerator = None,
 ):
     # create model from config
     diffusion_prior = prior_config.create()
@@ -54,71 +57,214 @@ def make_model(
         use_ema=train_config.use_ema,
         device=device,
         accelerator=accelerator,
+        warmup_steps=train_config.warmup_steps,
     )
 
     return trainer
 
+
+def create_tracker(
+    accelerator: Accelerator,
+    config: TrainDiffusionPriorConfig,
+    config_path: str,
+    dummy: bool = False,
+) -> Tracker:
+    tracker_config = config.tracker
+
+    accelerator_config = {
+        "Distributed": accelerator.distributed_type
+        != accelerate_dataclasses.DistributedType.NO,
+        "DistributedType": accelerator.distributed_type,
+        "NumProcesses": accelerator.num_processes,
+        "MixedPrecision": accelerator.mixed_precision,
+    }
+
+    tracker: Tracker = tracker_config.create(
+        config, accelerator_config, dummy_mode=dummy
+    )
+
+    tracker.save_config(config_path, config_name="prior_config.json")
+
+    return tracker
+
+
+def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
+    """
+    pad a value or tensor across all processes and gather
+
+    params:
+        - trainer: a trainer that carries an accelerator object
+        - x: a number or torch tensor to reduce
+        - method: "mean", "sum", "max", "min"
+
+    return:
+        - the average tensor after maskin out 0's
+        - None if the gather resulted in an empty tensor
+    """
+
+    assert method in [
+        "mean",
+        "sum",
+        "max",
+        "min",
+    ], "This function has limited capabilities [sum, mean, max, min]"
+    assert type(x) is not None, "Cannot reduce a None type object"
+
+    # wait for everyone to arrive here before gathering
+
+    if type(x) is not torch.Tensor:
+        x = torch.tensor([x])
+
+    # verify that the tensor is on the proper device
+    x = x.to(trainer.device)
+
+    # pad across processes
+    padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
+
+    # gather across all procesess
+    gathered_x = trainer.accelerator.gather(padded_x)
+
+    # mask out zeros
+    masked_x = gathered_x[gathered_x != 0]
+
+    # if the tensor is empty, warn and return None
+    if len(masked_x) == 0:
+        click.secho(
+            f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
+            fg="red",
+        )
+        return None
+
+    if method == "mean":
+        return torch.mean(masked_x)
+    elif method == "sum":
+        return torch.sum(masked_x)
+    elif method == "max":
+        return torch.max(masked_x)
+    elif method == "min":
+        return torch.min(masked_x)
+
+
+def save_trainer(
+    tracker: Tracker,
+    trainer: DiffusionPriorTrainer,
+    is_latest: bool,
+    is_best: bool,
+    epoch: int,
+    samples_seen: int,
+    best_validation_loss: float,
+):
+    """
+    Logs the model with an appropriate method depending on the tracker
+    """
+
+    trainer.accelerator.wait_for_everyone()
+
+    if trainer.accelerator.is_main_process:
+        click.secho(
+            f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
+            fg="magenta",
+        )
+
+    tracker.save(
+        trainer=trainer,
+        is_best=is_best,
+        is_latest=is_latest,
+        epoch=int(epoch),
+        samples_seen=int(samples_seen),
+        best_validation_loss=best_validation_loss,
+    )
+
+
+def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
+    """
+    Loads the model with an appropriate method depending on the tracker
+    """
+
+    if trainer.accelerator.is_main_process:
+        click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
+
+    state_dict = tracker.recall()
+
+    trainer.load(state_dict, strict=True)
+
+    return (
+        int(state_dict.get("epoch", 0)),
+        state_dict.get("best_validation_loss", 0),
+        int(state_dict.get("samples_seen", 0)),
+    )
+
+
 # eval functions
 
-def eval_model(
+def report_validation_loss(
     trainer: DiffusionPriorTrainer,
     dataloader: DataLoader,
     text_conditioned: bool,
+    use_ema: bool,
+    tracker: Tracker,
+    split: str,
+    tracker_folder: str,
     loss_type: str,
-    tracker_context: str,
-    tracker: BaseTracker = None,
-    use_ema: bool = True,
 ):
-    trainer.eval()
-    if trainer.is_main_process():
-        click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
-
-    with torch.no_grad():
-        total_loss = 0.0
-        total_samples = 0.0
-
-        for image_embeddings, text_data in dataloader:
-            image_embeddings = image_embeddings.to(trainer.device)
-            text_data = text_data.to(trainer.device)
-
-            batches = image_embeddings.shape[0]
-
-            input_args = dict(image_embed=image_embeddings)
-
-            if text_conditioned:
-                input_args = dict(**input_args, text=text_data)
-            else:
-                input_args = dict(**input_args, text_embed=text_data)
-
-            if use_ema:
-                loss = trainer.ema_diffusion_prior(**input_args)
-            else:
-                loss = trainer(**input_args)
-
-            total_loss += loss * batches
-            total_samples += batches
-
-        avg_loss = total_loss / total_samples
-        stats = {f"{tracker_context}-{loss_type}": avg_loss}
-        trainer.print(stats)
-
-        if exists(tracker):
-            tracker.log(stats, step=trainer.step.item() + 1)
+    """
+    Compute the validation loss on a given subset of data.
+    """
+
+    if trainer.accelerator.is_main_process:
+        click.secho(
+            f"Measuring performance on {use_ema}-{split} split",
+            fg="green",
+            blink=True,
+        )
+
+    total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
+
+    for image_embeddings, text_data in dataloader:
+        image_embeddings = image_embeddings.to(trainer.device)
+        text_data = text_data.to(trainer.device)
+
+        input_args = dict(image_embed=image_embeddings)
+
+        if text_conditioned:
+            input_args = dict(**input_args, text=text_data)
+        else:
+            input_args = dict(**input_args, text_embed=text_data)
+
+        if use_ema:
+            loss = trainer.ema_diffusion_prior(**input_args)
+        else:
+            loss = trainer(**input_args)
+
+        total_loss += loss
+
+    # compute the average loss across all processes
+
+    avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
+    stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
+
+    # print and log results on main process
+    tracker.log(stats, step=trainer.step.item() + 1)
+
+    return avg_loss
 
 
 def report_cosine_sims(
     trainer: DiffusionPriorTrainer,
     dataloader: DataLoader,
     text_conditioned: bool,
-    tracker: BaseTracker,
-    tracker_context: str = "validation",
+    tracker: Tracker,
+    split: str,
+    timesteps: int,
+    tracker_folder: str,
 ):
     trainer.eval()
-    if trainer.is_main_process():
-        click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
+    if trainer.accelerator.is_main_process:
+        click.secho(
+            f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
+            fg="green",
+            blink=True,
+        )
 
     for test_image_embeddings, text_data in dataloader:
         test_image_embeddings = test_image_embeddings.to(trainer.device)
@@ -127,9 +273,7 @@ def report_cosine_sims(
         # we are text conditioned, we produce an embedding from the tokenized text
         if text_conditioned:
             text_embedding, text_encodings = trainer.embed_text(text_data)
-            text_cond = dict(
-                text_embed=text_embedding, text_encodings=text_encodings
-            )
+            text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
         else:
             text_embedding = text_data
             text_cond = dict(text_embed=text_embedding)
@@ -150,8 +294,7 @@ def report_cosine_sims(
             text_encodings_shuffled = None
 
         text_cond_shuffled = dict(
-            text_embed=text_embed_shuffled,
-            text_encodings=text_encodings_shuffled
+            text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
         )
 
         # prepare the text embedding
@@ -164,7 +307,9 @@ def report_cosine_sims(
         # predict on the unshuffled text embeddings
         predicted_image_embeddings = trainer.p_sample_loop(
-            test_image_embeddings.shape, text_cond
+            test_image_embeddings.shape,
+            text_cond,
+            timesteps=timesteps,
         )
 
         predicted_image_embeddings = (
@@ -174,7 +319,9 @@ def report_cosine_sims(
         # predict on the shuffled embeddings
         predicted_unrelated_embeddings = trainer.p_sample_loop(
-            test_image_embeddings.shape, text_cond_shuffled
+            test_image_embeddings.shape,
+            text_cond_shuffled,
+            timesteps=timesteps,
        )
 
        predicted_unrelated_embeddings = (
@@ -183,32 +330,97 @@ def report_cosine_sims(
         )
 
         # calculate similarities
-        original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
-        unrelated_similarity = (
-            cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
-        )
-        predicted_img_similarity = (
-            cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
-        )
+        orig_sim = pad_gather_reduce(
+            trainer, cos(text_embed, test_image_embeddings), method="mean"
+        )
+        pred_sim = pad_gather_reduce(
+            trainer, cos(text_embed, predicted_image_embeddings), method="mean"
+        )
+        unrel_sim = pad_gather_reduce(
+            trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
+        )
+        pred_img_sim = pad_gather_reduce(
+            trainer,
+            cos(test_image_embeddings, predicted_image_embeddings),
+            method="mean",
+        )
 
         stats = {
-            f"{tracker_context}/baseline similarity": np.mean(original_similarity),
-            f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
-            f"{tracker_context}/similarity with original image": np.mean(
-                predicted_img_similarity
-            ),
-            f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
-            f"{tracker_context}/difference from baseline similarity": np.mean(
-                predicted_similarity - original_similarity
-            ),
+            f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
+            f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
+            f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
+            f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
+            f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
+            - orig_sim,
         }
 
-        for k, v in stats.items():
-            trainer.print(f"{tracker_context}/{k}: {v}")
-
-        if exists(tracker):
-            tracker.log(stats, step=trainer.step.item() + 1)
+        tracker.log(stats, step=trainer.step.item() + 1)
+
+
+def eval_model(
+    trainer: DiffusionPriorTrainer,
+    dataloader: DataLoader,
+    text_conditioned: bool,
+    split: str,
+    tracker: Tracker,
+    use_ema: bool,
+    report_cosine: bool,
+    report_loss: bool,
+    timesteps: List[int],
+    loss_type: str = None,
+):
+    """
+    Run evaluation on a model and track metrics
+
+    returns: loss if requested
+    """
+
+    trainer.eval()
+
+    use_ema = "ema" if use_ema else "online"
+    tracker_folder = f"metrics/{use_ema}-{split}"
+
+    # detemine if valid timesteps are passed
+
+    min_timesteps = trainer.accelerator.unwrap_model(
+        trainer.diffusion_prior
+    ).sample_timesteps
+    max_timesteps = trainer.accelerator.unwrap_model(
+        trainer.diffusion_prior
+    ).noise_scheduler.num_timesteps
+
+    assert all_between(
+        timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
+    ), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
+
+    # measure cosine metrics across various eta and timesteps
+
+    if report_cosine:
+        for timestep in timesteps:
+            report_cosine_sims(
+                trainer,
+                dataloader=dataloader,
+                text_conditioned=text_conditioned,
+                tracker=tracker,
+                split=split,
+                timesteps=timestep,
+                tracker_folder=tracker_folder,
+            )
+
+    # measure loss on a seperate split of data
+
+    if report_loss:
+        loss = report_validation_loss(
+            trainer=trainer,
+            dataloader=dataloader,
+            text_conditioned=text_conditioned,
+            use_ema=use_ema,
+            tracker=tracker,
+            split=split,
+            tracker_folder=tracker_folder,
+            loss_type=loss_type,
+        )
+
+        return loss
+
 
 # training script
@@ -216,182 +428,327 @@
 def train(
     trainer: DiffusionPriorTrainer,
+    tracker: Tracker,
     train_loader: DataLoader,
     eval_loader: DataLoader,
     test_loader: DataLoader,
     config: DiffusionPriorTrainConfig,
 ):
-    # distributed tracking with wandb
-    if trainer.accelerator.num_processes > 1:
-        os.environ["WANDB_START_METHOD"] = "thread"
-
-    tracker = wandb.init(
-        name=f"RANK:{trainer.device}",
-        entity=config.tracker.wandb_entity,
-        project=config.tracker.wandb_project,
-        config=config.dict(),
-        group=GROUP,
-    )
-
-    # sync after tracker init
-    trainer.wait_for_everyone()
-
-    # init a timer
-    timer = Timer()
-
-    # do training
-    for img, txt in train_loader:
-        trainer.train()
-        current_step = trainer.step.item() + 1
-
-        # place data on device
-        img = img.to(trainer.device)
-        txt = txt.to(trainer.device)
-
-        # pass to model
-        loss = trainer(text=txt, image_embed=img)
-
-        # display & log loss (will only print from main process)
-        trainer.print(f"Step {current_step}: Loss {loss}")
-
-        # perform backprop & apply EMA updates
-        trainer.update()
-
-        # track samples/sec/rank
-        samples_per_sec = img.shape[0] / timer.elapsed()
-
-        # samples seen
-        samples_seen = (
-            config.data.batch_size * trainer.accelerator.num_processes * current_step
-        )
-
-        # ema decay
-        ema_decay = trainer.ema_diffusion_prior.get_current_decay()
-
-        # Log on all processes for debugging
-        tracker.log(
-            {
-                "tracking/samples-sec": samples_per_sec,
-                "tracking/samples-seen": samples_seen,
-                "tracking/ema-decay": ema_decay,
-                "metrics/training-loss": loss,
-            },
-            step=current_step,
-        )
-
-        # Metric Tracking & Checkpointing (outside of timer's scope)
-        if current_step % EVAL_EVERY == 0:
-            eval_model(
-                trainer=trainer,
-                dataloader=eval_loader,
-                text_conditioned=config.prior.condition_on_text_encodings,
-                loss_type=config.prior.loss_type,
-                tracker_context="metrics/online-model-validation",
-                tracker=tracker,
-                use_ema=False,
-            )
-            eval_model(
-                trainer=trainer,
-                dataloader=eval_loader,
-                text_conditioned=config.prior.condition_on_text_encodings,
-                loss_type=config.prior.loss_type,
-                tracker_context="metrics/ema-model-validation",
-                tracker=tracker,
-                use_ema=True,
-            )
-            report_cosine_sims(
-                trainer=trainer,
-                dataloader=eval_loader,
-                text_conditioned=config.prior.condition_on_text_encodings,
-                tracker=tracker,
-                tracker_context="metrics",
-            )
-
-        if current_step % config.train.save_every == 0:
-            trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
-
-        # reset timer for next round
-        timer.reset()
+    # init timers
+    save_timer = Timer()  # when to save
+    samples_timer = Timer()  # samples/sec
+    validation_profiler = Timer()  # how long is validation taking
+    validation_countdown = Timer()  # when to perform evalutation
+
+    # keep track of best validation loss
+
+    best_validation_loss = config.train.best_validation_loss
+    samples_seen = config.train.num_samples_seen
+
+    # do training
+
+    start_epoch = config.train.current_epoch
+
+    for epoch in range(start_epoch, config.train.epochs):
+        # if we finished out an old epoch, reset the distribution to be a full epoch
+        tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
+
+        if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
+            if trainer.accelerator.is_main_process:
+                click.secho(f"Finished resumed epoch...resetting dataloader.")
+            train_loader.dataset.set_start(0)
+
+        for img, txt in train_loader:
+            # setup things every step
+
+            trainer.train()
+            current_step = trainer.step.item()
+            samples_timer.reset()
+
+            # place data on device
+
+            img = img.to(trainer.device)
+            txt = txt.to(trainer.device)
+
+            # pass to model
+
+            loss = trainer(text=txt, image_embed=img)
+
+            # perform backprop & apply EMA updates
+
+            trainer.update()
+
+            # gather info about training step
+
+            all_loss = pad_gather_reduce(trainer, loss, method="mean")
+            num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
+            samples_per_sec = num_samples / samples_timer.elapsed()
+            samples_seen += num_samples
+            ema_decay = trainer.ema_diffusion_prior.get_current_decay()
+
+            # log
+
+            tracker.log(
+                {
+                    "tracking/samples-sec": samples_per_sec,
+                    "tracking/samples-seen": samples_seen,
+                    "tracking/ema-decay": ema_decay,
+                    f"tracking/training-{config.prior.loss_type}": all_loss,
+                },
+                step=current_step,
+            )
+
+            # Metric Tracking @ Timed Intervals
+
+            eval_delta = pad_gather_reduce(
+                trainer, validation_countdown.elapsed(), method="min"
+            )
+
+            if eval_delta != None and eval_delta > config.data.eval_every_seconds:
+                # begin timing how long this takes
+
+                validation_profiler.reset()
+
+                # package kwargs for evaluation
+
+                eval_kwargs = {
+                    "trainer": trainer,
+                    "tracker": tracker,
+                    "text_conditioned": config.prior.condition_on_text_encodings,
+                    "timesteps": config.train.eval_timesteps,
+                }
+
+                # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
+
+                eval_model(
+                    dataloader=eval_loader,
+                    loss_type=config.prior.loss_type,
+                    split="validation",
+                    use_ema=False,
+                    report_cosine=False,
+                    report_loss=True,
+                    **eval_kwargs,
+                )
+
+                # EMA MODEL : COSINE : LOSS : VALIDATION DATA
+
+                ema_val_loss = eval_model(
+                    dataloader=eval_loader,
+                    loss_type=config.prior.loss_type,
+                    split="validation",
+                    use_ema=True,
+                    report_cosine=True,
+                    report_loss=True,
+                    **eval_kwargs,
+                )
+
+                tracker.log(
+                    {
+                        "tracking/validation length (minutes)": validation_profiler.elapsed()
+                        / 60
+                    }
+                )
+
+                # check if the ema validation is the lowest seen yet
+
+                if ema_val_loss < best_validation_loss:
+                    best_validation_loss = ema_val_loss
+
+                    # go save the model as best
+
+                    save_trainer(
+                        trainer=trainer,
+                        tracker=tracker,
+                        is_best=True,
+                        is_latest=False,
+                        samples_seen=samples_seen,
+                        epoch=epoch,
+                        best_validation_loss=best_validation_loss,
+                    )
+
+                # reset timer for validaiton
+
+                validation_countdown.reset()
+
+            elif eval_delta is None:
+                click.secho(
+                    f"Error occured reading the eval time on rank: {trainer.device}",
+                    fg="yellow",
+                )
+
+            # save as latest model on schedule
+
+            save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
+
+            if save_delta != None and save_delta >= config.train.save_every_seconds:
+                save_trainer(
+                    trainer=trainer,
+                    tracker=tracker,
+                    is_best=False,
+                    is_latest=True,
+                    samples_seen=samples_seen,
+                    epoch=epoch,
+                    best_validation_loss=best_validation_loss,
+                )
+
+                save_timer.reset()
+
+            elif save_delta is None:
+                click.secho(
+                    f"Error occured reading the save time on rank: {trainer.device}",
+                    fg="yellow",
+                )
     # evaluate on test data
 
-    eval_model(
-        trainer=trainer,
-        dataloader=test_loader,
-        text_conditioned=config.prior.condition_on_text_encodings,
-        loss_type=config.prior.loss_type,
-        tracker_context="test",
-        tracker=tracker,
-    )
-
-    report_cosine_sims(
-        trainer,
-        test_loader,
-        config.prior.condition_on_text_encodings,
-        tracker,
-        tracker_context="test",
-    )
+    if trainer.accelerator.is_main_process:
+        click.secho(f"Starting Test", fg="red")
+
+    # save one last time as latest before beginning validation
+
+    save_trainer(
+        tracker=tracker,
+        trainer=trainer,
+        is_best=False,
+        is_latest=True,
+        samples_seen=samples_seen,
+        epoch=epoch,
+        best_validation_loss=best_validation_loss,
+    )
+
+    test_loss = eval_model(
+        trainer=trainer,
+        dataloader=test_loader,
+        text_conditioned=config.prior.condition_on_text_encodings,
+        split="test",
+        tracker=tracker,
+        use_ema=True,
+        report_cosine=False,
+        report_loss=True,
+        timesteps=config.train.eval_timesteps,
+        loss_type=config.prior.loss_type,
+    )
+
+    if test_loss < best_validation_loss:
+        best_validation_loss = test_loss
+
+        # go save the model as best
+
+        save_trainer(
+            trainer=trainer,
+            tracker=tracker,
+            is_best=True,
+            is_latest=False,
+            samples_seen=samples_seen,
+            epoch=epoch,
+            best_validation_loss=test_loss,
+        )
-def initialize_training(config, accelerator=None):
+def initialize_training(config_file, accelerator):
     """
     Parse the configuration file, and prepare everything necessary for training
     """
+    # load the configuration file
+    if accelerator.is_main_process:
+        click.secho(f"Loading configuration from {config_file}", fg="green")
+
+    config = TrainDiffusionPriorConfig.from_json_path(config_file)
+
+    # seed
+
+    set_seed(config.train.random_seed)
+
     # get a device
-    if accelerator:
-        device = accelerator.device
-        click.secho(f"Accelerating on: {device}", fg="yellow")
-    else:
-        if torch.cuda.is_available():
-            click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
-            device = "cuda:0"
-        else:
-            click.secho("No GPU detected...using cpu", fg="yellow")
-            device = "cpu"
+
+    device = accelerator.device
 
     # make the trainer (will automatically distribute if possible & configured)
-    trainer = make_model(config.prior, config.train, device, accelerator).to(device)
+
+    trainer: DiffusionPriorTrainer = make_model(
+        config.prior, config.train, device, accelerator
+    ).to(device)
+
+    # create a tracker
+
+    tracker = create_tracker(
+        accelerator, config, config_file, dummy=accelerator.process_index != 0
+    )
 
     # reload from chcekpoint
-    if config.load.resume == True:
-        click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
-        trainer.load(config.load.source)
+
+    if tracker.can_recall:
+        current_epoch, best_validation_loss, samples_seen = recall_trainer(
+            tracker=tracker, trainer=trainer
+        )
+
+        # display best values
+        if trainer.accelerator.is_main_process:
+            click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
+
+        # update config to reflect recalled values
+        config.train.num_samples_seen = samples_seen
+        config.train.current_epoch = current_epoch
+        config.train.best_validation_loss = best_validation_loss
 
     # fetch and prepare data
-    if trainer.is_main_process():
-        click.secho("Grabbing data from source", fg="blue", blink=True)
+
+    if trainer.accelerator.is_main_process:
+        click.secho("Grabbing data...", fg="blue", blink=True)
+
+    trainer.accelerator.wait_for_everyone()
 
     img_reader = get_reader(
         text_conditioned=trainer.text_conditioned,
         img_url=config.data.image_url,
         meta_url=config.data.meta_url,
     )
 
+    # calculate start point within epoch
+
+    trainer.accelerator.wait_for_everyone()
+
     train_loader, eval_loader, test_loader = make_splits(
         text_conditioned=trainer.text_conditioned,
         batch_size=config.data.batch_size,
-        num_data_points=NUM_DATA_POINTS,
+        num_data_points=config.data.num_data_points,
         train_split=config.data.splits.train,
         eval_split=config.data.splits.val,
         image_reader=img_reader,
-        rank=accelerator.state.process_index if exists(accelerator) else 0,
-        world_size=accelerator.state.num_processes if exists(accelerator) else 1,
-        start=START,
+        rank=accelerator.state.process_index,
+        world_size=accelerator.state.num_processes,
+        start=0,
     )
 
-    # wait for everyone to load data before continuing
-    trainer.wait_for_everyone()
+    # update the start point to finish out the epoch on a resumed run
+
+    if tracker.can_recall:
+        samples_seen = config.train.num_samples_seen
+        length = (
+            config.data.num_data_points
+            if samples_seen <= img_reader.count
+            else img_reader.count
+        )
+        scaled_samples = length * config.train.current_epoch
+        start_point = (
+            scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
+        )
+
+        if trainer.accelerator.is_main_process:
+            click.secho(f"Resuming at sample: {start_point}", fg="yellow")
+
+        train_loader.dataset.set_start(start_point)
 
     # start training
+
+    if trainer.accelerator.is_main_process:
+        click.secho(
+            f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
+            fg="yellow",
+        )
+
     train(
         trainer=trainer,
+        tracker=tracker,
         train_loader=train_loader,
         eval_loader=eval_loader,
         test_loader=test_loader,
@@ -400,23 +757,13 @@ def initialize_training(config, accelerator=None):
 @click.command()
-@click.option("--hfa", default=True)
-@click.option("--config_path", default="configs/prior.json")
-def main(hfa, config_path):
-    # start HFA if requested
-    if hfa:
-        accelerator = Accelerator()
-    else:
-        accelerator = None
-
-    # load the configuration file on main process
-    if not exists(accelerator) or accelerator.is_main_process:
-        click.secho(f"Loading configuration from {config_path}", fg="green")
-    config = TrainDiffusionPriorConfig.from_json_path(config_path)
-
-    # send config to get processed
-    initialize_training(config, accelerator)
+@click.option("--config_file", default="configs/train_prior_config.example.json")
+def main(config_file):
+    # start HFA
+    accelerator = Accelerator()
+
+    # setup training
+    initialize_training(config_file, accelerator)
 if __name__ == "__main__":