mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Prior updates (#211)
* update configs for prior add prior warmup to config update example prior config * update prior trainer & script add deepspeed amp & warmup adopt full accelerator support reload at sample point finish epoch resume code * update tracker save method for prior * helper functions for prior_loader
This commit is contained in:
@@ -1,18 +1,14 @@
|
||||
{
|
||||
"prior": {
|
||||
"clip": {
|
||||
"make": "x-clip",
|
||||
"model": "ViT-L/14",
|
||||
"base_model_kwargs": {
|
||||
"dim_text": 768,
|
||||
"dim_image": 768,
|
||||
"dim_latent": 768
|
||||
}
|
||||
"make": "openai",
|
||||
"model": "ViT-L/14"
|
||||
},
|
||||
"net": {
|
||||
"dim": 768,
|
||||
"depth": 12,
|
||||
"num_timesteps": 1000,
|
||||
"max_text_len": 77,
|
||||
"num_time_embeds": 1,
|
||||
"num_image_embeds": 1,
|
||||
"num_text_embeds": 1,
|
||||
@@ -20,8 +16,8 @@
|
||||
"heads": 12,
|
||||
"ff_mult": 4,
|
||||
"norm_out": true,
|
||||
"attn_dropout": 0.0,
|
||||
"ff_dropout": 0.0,
|
||||
"attn_dropout": 0.05,
|
||||
"ff_dropout": 0.05,
|
||||
"final_proj": true,
|
||||
"normformer": true,
|
||||
"rotary_emb": true
|
||||
@@ -30,6 +26,7 @@
|
||||
"image_size": 224,
|
||||
"image_channels": 3,
|
||||
"timesteps": 1000,
|
||||
"sample_timesteps": 64,
|
||||
"cond_drop_prob": 0.1,
|
||||
"loss_type": "l2",
|
||||
"predict_x_start": true,
|
||||
@@ -37,34 +34,48 @@
|
||||
"condition_on_text_encodings": true
|
||||
},
|
||||
"data": {
|
||||
"image_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/",
|
||||
"text_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/",
|
||||
"meta_url": "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/laion2B-en-metadata/",
|
||||
"batch_size": 256,
|
||||
"batch_size": 128,
|
||||
"num_data_points": 100000,
|
||||
"eval_every_seconds": 1600,
|
||||
"image_url": "<path to your images>",
|
||||
"meta_url": "<path to your metadata>",
|
||||
"splits": {
|
||||
"train": 0.9,
|
||||
"val": 1e-7,
|
||||
"test": 0.0999999
|
||||
"train": 0.8,
|
||||
"val": 0.1,
|
||||
"test": 0.1
|
||||
}
|
||||
},
|
||||
"train": {
|
||||
"epochs": 1,
|
||||
"epochs": 5,
|
||||
"lr": 1.1e-4,
|
||||
"wd": 6.02e-2,
|
||||
"max_grad_norm": 0.5,
|
||||
"use_ema": true,
|
||||
"ema_beta": 0.9999,
|
||||
"ema_update_after_step": 50,
|
||||
"warmup_steps": 50,
|
||||
"amp": false,
|
||||
"save_every": 10000
|
||||
},
|
||||
"load": {
|
||||
"source": null,
|
||||
"resume": false
|
||||
"save_every_seconds": 3600,
|
||||
"eval_timesteps": [64, 1000],
|
||||
"random_seed": 84513
|
||||
},
|
||||
"tracker": {
|
||||
"tracker_type": "wandb",
|
||||
"data_path": "./prior_checkpoints",
|
||||
"wandb_entity": "laion",
|
||||
"wandb_project": "diffusion-prior",
|
||||
"data_path": ".prior",
|
||||
"overwrite_data_path": true,
|
||||
"log": {
|
||||
"log_type": "wandb",
|
||||
"wandb_entity": "<your wandb username>",
|
||||
"wandb_project": "prior_debugging",
|
||||
"wandb_resume": false,
|
||||
"verbose": true
|
||||
},
|
||||
"save": [
|
||||
{
|
||||
"save_to": "local",
|
||||
"save_type": "checkpoint",
|
||||
"save_latest_to": ".prior/latest_checkpoint.pth",
|
||||
"save_best_to": ".prior/best_checkpoint.pth"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +67,15 @@ class PriorEmbeddingDataset(IterableDataset):
|
||||
def __str__(self):
|
||||
return f"<PriorEmbeddingDataset: start: {self.start}, stop: {self.stop}, len: {self.__len__()}>"
|
||||
|
||||
def set_start(self, start):
|
||||
"""
|
||||
Adjust the starting point within the reader, useful for resuming an epoch
|
||||
"""
|
||||
self.start = start
|
||||
|
||||
def get_start(self):
|
||||
return self.start
|
||||
|
||||
def get_sample(self):
|
||||
"""
|
||||
pre-proocess data from either reader into a common format
|
||||
|
||||
@@ -528,12 +528,8 @@ class Tracker:
|
||||
elif save_type == 'model':
|
||||
if isinstance(trainer, DiffusionPriorTrainer):
|
||||
prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
|
||||
prior: DiffusionPrior = trainer.unwrap_model(prior)
|
||||
# Remove CLIP if it is part of the model
|
||||
original_clip = prior.clip
|
||||
prior.clip = None
|
||||
model_state_dict = prior.state_dict()
|
||||
prior.clip = original_clip
|
||||
state_dict = trainer.accelerator.unwrap_model(prior).state_dict()
|
||||
torch.save(state_dict, file_path)
|
||||
elif isinstance(trainer, DecoderTrainer):
|
||||
decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
|
||||
# Remove CLIP if it is part of the model
|
||||
|
||||
@@ -145,6 +145,9 @@ class DiffusionPriorNetworkConfig(BaseModel):
|
||||
normformer: bool = False
|
||||
rotary_emb: bool = True
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
def create(self):
|
||||
kwargs = self.dict()
|
||||
return DiffusionPriorNetwork(**kwargs)
|
||||
@@ -187,23 +190,26 @@ class DiffusionPriorTrainConfig(BaseModel):
|
||||
use_ema: bool = True
|
||||
ema_beta: float = 0.99
|
||||
amp: bool = False
|
||||
save_every: int = 10000 # what steps to save on
|
||||
warmup_steps: int = None # number of warmup steps
|
||||
save_every_seconds: int = 3600 # how often to save
|
||||
eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
|
||||
best_validation_loss: float = 1e9 # the current best valudation loss observed
|
||||
current_epoch: int = 0 # the current epoch
|
||||
num_samples_seen: int = 0 # the current number of samples seen
|
||||
random_seed: int = 0 # manual seed for torch
|
||||
|
||||
class DiffusionPriorDataConfig(BaseModel):
|
||||
image_url: str # path to embeddings folder
|
||||
meta_url: str # path to metadata (captions) for images
|
||||
splits: TrainSplitConfig
|
||||
batch_size: int = 64
|
||||
|
||||
class DiffusionPriorLoadConfig(BaseModel):
|
||||
source: str = None
|
||||
resume: bool = False
|
||||
splits: TrainSplitConfig # define train, validation, test splits for your dataset
|
||||
batch_size: int # per-gpu batch size used to train the model
|
||||
num_data_points: int = 25e7 # total number of datapoints to train on
|
||||
eval_every_seconds: int = 3600 # validation statistics will be performed this often
|
||||
|
||||
class TrainDiffusionPriorConfig(BaseModel):
|
||||
prior: DiffusionPriorConfig
|
||||
data: DiffusionPriorDataConfig
|
||||
train: DiffusionPriorTrainConfig
|
||||
load: DiffusionPriorLoadConfig
|
||||
tracker: TrackerConfig
|
||||
|
||||
@classmethod
|
||||
@@ -323,12 +329,6 @@ class DecoderEvaluateConfig(BaseModel):
|
||||
KID: Dict[str, Any] = None
|
||||
LPIPS: Dict[str, Any] = None
|
||||
|
||||
class DecoderLoadConfig(BaseModel):
|
||||
source: str = None # Supports file and wandb
|
||||
run_path: str = '' # Used only if source is wandb
|
||||
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
|
||||
resume: bool = False # If using wandb, whether to resume the run
|
||||
|
||||
class TrainDecoderConfig(BaseModel):
|
||||
decoder: DecoderConfig
|
||||
data: DecoderDataConfig
|
||||
|
||||
@@ -174,27 +174,21 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_prior,
|
||||
accelerator,
|
||||
use_ema = True,
|
||||
lr = 3e-4,
|
||||
wd = 1e-2,
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
device = None,
|
||||
accelerator = None,
|
||||
verbose = True,
|
||||
warmup_steps = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(diffusion_prior, DiffusionPrior)
|
||||
assert not exists(accelerator) or isinstance(accelerator, Accelerator)
|
||||
assert isinstance(accelerator, Accelerator)
|
||||
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||
|
||||
# verbosity
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
# assign some helpful member vars
|
||||
|
||||
self.accelerator = accelerator
|
||||
@@ -202,23 +196,31 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# setting the device
|
||||
|
||||
if not exists(accelerator) and not exists(device):
|
||||
diffusion_prior_device = next(diffusion_prior.parameters()).device
|
||||
self.print(f'accelerator not given, and device not specified: defaulting to device of diffusion prior parameters - {diffusion_prior_device}')
|
||||
self.device = diffusion_prior_device
|
||||
else:
|
||||
self.device = accelerator.device if exists(accelerator) else device
|
||||
self.device = accelerator.device
|
||||
diffusion_prior.to(self.device)
|
||||
|
||||
# save model
|
||||
|
||||
self.diffusion_prior = diffusion_prior
|
||||
|
||||
# optimizer and mixed precision stuff
|
||||
# mixed precision checks
|
||||
|
||||
self.amp = amp
|
||||
if (
|
||||
exists(self.accelerator)
|
||||
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
and self.diffusion_prior.clip is not None
|
||||
):
|
||||
# Then we need to make sure clip is using the correct precision or else deepspeed will error
|
||||
cast_type_map = {
|
||||
"fp16": torch.half,
|
||||
"bf16": torch.bfloat16,
|
||||
"no": torch.float
|
||||
}
|
||||
precision_type = cast_type_map[accelerator.mixed_precision]
|
||||
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
|
||||
self.diffusion_prior.clip.to(precision_type)
|
||||
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
# optimizer stuff
|
||||
|
||||
self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
|
||||
|
||||
@@ -228,16 +230,20 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||
|
||||
# distribute the model if using HFA
|
||||
if exists(self.accelerator):
|
||||
self.diffusion_prior, self.optimizer = self.accelerator.prepare(self.diffusion_prior, self.optimizer)
|
||||
|
||||
self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
|
||||
|
||||
# exponential moving average stuff
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(self.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
|
||||
|
||||
# gradient clipping if needed
|
||||
|
||||
@@ -247,67 +253,24 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
self.register_buffer('step', torch.tensor([0], device = self.device))
|
||||
|
||||
# accelerator wrappers
|
||||
|
||||
def print(self, msg):
|
||||
if not self.verbose:
|
||||
return
|
||||
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
|
||||
def unwrap_model(self, model):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.unwrap_model(model)
|
||||
else:
|
||||
return model
|
||||
|
||||
def wait_for_everyone(self):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
def is_main_process(self):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.is_main_process
|
||||
else:
|
||||
return True
|
||||
|
||||
def clip_grad_norm_(self, *args):
|
||||
if exists(self.accelerator):
|
||||
return self.accelerator.clip_grad_norm_(*args)
|
||||
else:
|
||||
return torch.nn.utils.clip_grad_norm_(*args)
|
||||
|
||||
def backprop(self, x):
|
||||
if exists(self.accelerator):
|
||||
self.accelerator.backward(x)
|
||||
else:
|
||||
try:
|
||||
x.backward()
|
||||
except Exception as e:
|
||||
self.print(f"Caught error in backprop call: {e}")
|
||||
|
||||
# utility
|
||||
|
||||
def save(self, path, overwrite = True, **kwargs):
|
||||
# ensure we sync gradients before continuing
|
||||
self.wait_for_everyone()
|
||||
|
||||
# only save on the main process
|
||||
if self.is_main_process():
|
||||
self.print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
if self.accelerator.is_main_process:
|
||||
print(f"Saving checkpoint at step: {self.step.item()}")
|
||||
path = Path(path)
|
||||
assert not (path.exists() and not overwrite)
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||
save_obj = dict(
|
||||
scaler = self.scaler.state_dict(),
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
model = self.unwrap_model(self.diffusion_prior).state_dict(), # unwrap the model from distribution if applicable
|
||||
warmup_scheduler = self.warmup_scheduler,
|
||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||
version = version.parse(__version__),
|
||||
step = self.step.item(),
|
||||
step = self.step,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -320,14 +283,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
torch.save(save_obj, str(path))
|
||||
|
||||
def load(self, path, overwrite_lr = True, strict = True):
|
||||
def load(self, path_or_state, overwrite_lr = True, strict = True):
|
||||
"""
|
||||
Load a checkpoint of a diffusion prior trainer.
|
||||
|
||||
Will load the entire trainer, including the optimizer and EMA.
|
||||
|
||||
Params:
|
||||
- path (str): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
|
||||
- overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer
|
||||
- strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
|
||||
|
||||
@@ -336,56 +299,56 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
"""
|
||||
|
||||
# all processes need to load checkpoint. no restriction here
|
||||
if isinstance(path_or_state, str):
|
||||
path = Path(path)
|
||||
assert path.exists()
|
||||
|
||||
loaded_obj = torch.load(str(path), map_location=self.device)
|
||||
|
||||
elif isinstance(path_or_state, dict):
|
||||
loaded_obj = path_or_state
|
||||
|
||||
if version.parse(__version__) != loaded_obj['version']:
|
||||
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
|
||||
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||
|
||||
self.scaler.load_state_dict(loaded_obj['scaler'])
|
||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
|
||||
# set warmupstep
|
||||
if exists(self.warmup_scheduler):
|
||||
self.warmup_scheduler.last_step = self.step.item()
|
||||
|
||||
# ensure new lr is used if different from old one
|
||||
if overwrite_lr:
|
||||
new_lr = self.optim_kwargs["lr"]
|
||||
|
||||
self.print(f"Overriding LR to be {new_lr}")
|
||||
|
||||
for group in self.optimizer.param_groups:
|
||||
group["lr"] = new_lr
|
||||
group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
|
||||
|
||||
if self.use_ema:
|
||||
assert 'ema' in loaded_obj
|
||||
self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
|
||||
# below not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
# below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
|
||||
self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
|
||||
|
||||
# sync and inform
|
||||
self.wait_for_everyone()
|
||||
self.print(f"Loaded model")
|
||||
|
||||
return loaded_obj
|
||||
|
||||
# model functionality
|
||||
|
||||
def update(self):
|
||||
# only continue with updates until all ranks finish
|
||||
self.wait_for_everyone()
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
# utilize HFA clipping where applicable
|
||||
self.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
|
||||
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
with self.warmup_scheduler.dampening():
|
||||
self.scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior.update()
|
||||
|
||||
@@ -414,7 +377,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
@cast_torch_tensor
|
||||
@prior_sample_in_chunks
|
||||
def embed_text(self, *args, **kwargs):
|
||||
return self.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
|
||||
|
||||
@cast_torch_tensor
|
||||
def forward(
|
||||
@@ -426,16 +389,14 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
total_loss = 0.
|
||||
|
||||
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
|
||||
with autocast(enabled = self.amp):
|
||||
with self.accelerator.autocast():
|
||||
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
|
||||
loss = loss * chunk_size_frac
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# backprop with accelerate if applicable
|
||||
|
||||
if self.training:
|
||||
self.backprop(self.scaler.scale(loss))
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
@@ -1,31 +1,23 @@
|
||||
# TODO: add start, num_data_points, eval_every and group to config
|
||||
# TODO: switch back to repo's wandb
|
||||
|
||||
START = 0
|
||||
NUM_DATA_POINTS = 250e6
|
||||
EVAL_EVERY = 1000
|
||||
GROUP = "distributed"
|
||||
|
||||
import os
|
||||
import click
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import List
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from embedding_reader import EmbeddingReader
|
||||
from accelerate.utils import dataclasses as accelerate_dataclasses
|
||||
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.utils import Timer
|
||||
from dalle2_pytorch.trackers import Tracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
from dalle2_pytorch.dataloaders import get_reader, make_splits
|
||||
from dalle2_pytorch.train_configs import (
|
||||
DiffusionPriorConfig,
|
||||
DiffusionPriorTrainConfig,
|
||||
TrainDiffusionPriorConfig,
|
||||
)
|
||||
from dalle2_pytorch.trackers import BaseTracker, WandbTracker
|
||||
from dalle2_pytorch import DiffusionPriorTrainer
|
||||
|
||||
|
||||
# helpers
|
||||
@@ -38,8 +30,19 @@ def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def all_between(values: list, lower_bound, upper_bound):
|
||||
for value in values:
|
||||
if value < lower_bound or value > upper_bound:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def make_model(
|
||||
prior_config, train_config, device: str = None, accelerator: Accelerator = None
|
||||
prior_config: DiffusionPriorConfig,
|
||||
train_config: DiffusionPriorTrainConfig,
|
||||
device: str = None,
|
||||
accelerator: Accelerator = None,
|
||||
):
|
||||
# create model from config
|
||||
diffusion_prior = prior_config.create()
|
||||
@@ -54,37 +57,173 @@ def make_model(
|
||||
use_ema=train_config.use_ema,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
warmup_steps=train_config.warmup_steps,
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def create_tracker(
|
||||
accelerator: Accelerator,
|
||||
config: TrainDiffusionPriorConfig,
|
||||
config_path: str,
|
||||
dummy: bool = False,
|
||||
) -> Tracker:
|
||||
tracker_config = config.tracker
|
||||
|
||||
accelerator_config = {
|
||||
"Distributed": accelerator.distributed_type
|
||||
!= accelerate_dataclasses.DistributedType.NO,
|
||||
"DistributedType": accelerator.distributed_type,
|
||||
"NumProcesses": accelerator.num_processes,
|
||||
"MixedPrecision": accelerator.mixed_precision,
|
||||
}
|
||||
|
||||
tracker: Tracker = tracker_config.create(
|
||||
config, accelerator_config, dummy_mode=dummy
|
||||
)
|
||||
|
||||
tracker.save_config(config_path, config_name="prior_config.json")
|
||||
|
||||
return tracker
|
||||
|
||||
|
||||
def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
|
||||
"""
|
||||
pad a value or tensor across all processes and gather
|
||||
|
||||
params:
|
||||
- trainer: a trainer that carries an accelerator object
|
||||
- x: a number or torch tensor to reduce
|
||||
- method: "mean", "sum", "max", "min"
|
||||
|
||||
return:
|
||||
- the average tensor after maskin out 0's
|
||||
- None if the gather resulted in an empty tensor
|
||||
"""
|
||||
|
||||
assert method in [
|
||||
"mean",
|
||||
"sum",
|
||||
"max",
|
||||
"min",
|
||||
], "This function has limited capabilities [sum, mean, max, min]"
|
||||
assert type(x) is not None, "Cannot reduce a None type object"
|
||||
|
||||
# wait for everyone to arrive here before gathering
|
||||
|
||||
if type(x) is not torch.Tensor:
|
||||
x = torch.tensor([x])
|
||||
|
||||
# verify that the tensor is on the proper device
|
||||
x = x.to(trainer.device)
|
||||
|
||||
# pad across processes
|
||||
padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
|
||||
|
||||
# gather across all procesess
|
||||
gathered_x = trainer.accelerator.gather(padded_x)
|
||||
|
||||
# mask out zeros
|
||||
masked_x = gathered_x[gathered_x != 0]
|
||||
|
||||
# if the tensor is empty, warn and return None
|
||||
if len(masked_x) == 0:
|
||||
click.secho(
|
||||
f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
|
||||
fg="red",
|
||||
)
|
||||
return None
|
||||
|
||||
if method == "mean":
|
||||
return torch.mean(masked_x)
|
||||
elif method == "sum":
|
||||
return torch.sum(masked_x)
|
||||
elif method == "max":
|
||||
return torch.max(masked_x)
|
||||
elif method == "min":
|
||||
return torch.min(masked_x)
|
||||
|
||||
|
||||
def save_trainer(
|
||||
tracker: Tracker,
|
||||
trainer: DiffusionPriorTrainer,
|
||||
is_latest: bool,
|
||||
is_best: bool,
|
||||
epoch: int,
|
||||
samples_seen: int,
|
||||
best_validation_loss: float,
|
||||
):
|
||||
"""
|
||||
Logs the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
|
||||
fg="magenta",
|
||||
)
|
||||
|
||||
tracker.save(
|
||||
trainer=trainer,
|
||||
is_best=is_best,
|
||||
is_latest=is_latest,
|
||||
epoch=int(epoch),
|
||||
samples_seen=int(samples_seen),
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
|
||||
def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
|
||||
"""
|
||||
Loads the model with an appropriate method depending on the tracker
|
||||
"""
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
|
||||
|
||||
state_dict = tracker.recall()
|
||||
|
||||
trainer.load(state_dict, strict=True)
|
||||
|
||||
return (
|
||||
int(state_dict.get("epoch", 0)),
|
||||
state_dict.get("best_validation_loss", 0),
|
||||
int(state_dict.get("samples_seen", 0)),
|
||||
)
|
||||
|
||||
|
||||
# eval functions
|
||||
|
||||
|
||||
def eval_model(
|
||||
def report_validation_loss(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
use_ema: bool,
|
||||
tracker: Tracker,
|
||||
split: str,
|
||||
tracker_folder: str,
|
||||
loss_type: str,
|
||||
tracker_context: str,
|
||||
tracker: BaseTracker = None,
|
||||
use_ema: bool = True,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
|
||||
"""
|
||||
Compute the validation loss on a given subset of data.
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
total_loss = 0.0
|
||||
total_samples = 0.0
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Measuring performance on {use_ema}-{split} split",
|
||||
fg="green",
|
||||
blink=True,
|
||||
)
|
||||
|
||||
total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
|
||||
|
||||
for image_embeddings, text_data in dataloader:
|
||||
image_embeddings = image_embeddings.to(trainer.device)
|
||||
text_data = text_data.to(trainer.device)
|
||||
|
||||
batches = image_embeddings.shape[0]
|
||||
|
||||
input_args = dict(image_embed=image_embeddings)
|
||||
|
||||
if text_conditioned:
|
||||
@@ -97,28 +236,35 @@ def eval_model(
|
||||
else:
|
||||
loss = trainer(**input_args)
|
||||
|
||||
total_loss += loss * batches
|
||||
total_samples += batches
|
||||
total_loss += loss
|
||||
|
||||
avg_loss = total_loss / total_samples
|
||||
# compute the average loss across all processes
|
||||
|
||||
stats = {f"{tracker_context}-{loss_type}": avg_loss}
|
||||
trainer.print(stats)
|
||||
avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
|
||||
stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
|
||||
|
||||
if exists(tracker):
|
||||
# print and log results on main process
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
return avg_loss
|
||||
|
||||
|
||||
def report_cosine_sims(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
tracker: BaseTracker,
|
||||
tracker_context: str = "validation",
|
||||
tracker: Tracker,
|
||||
split: str,
|
||||
timesteps: int,
|
||||
tracker_folder: str,
|
||||
):
|
||||
trainer.eval()
|
||||
if trainer.is_main_process():
|
||||
click.secho("Measuring Cosine-Similarity", fg="green", blink=True)
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
|
||||
fg="green",
|
||||
blink=True,
|
||||
)
|
||||
|
||||
for test_image_embeddings, text_data in dataloader:
|
||||
test_image_embeddings = test_image_embeddings.to(trainer.device)
|
||||
@@ -127,9 +273,7 @@ def report_cosine_sims(
|
||||
# we are text conditioned, we produce an embedding from the tokenized text
|
||||
if text_conditioned:
|
||||
text_embedding, text_encodings = trainer.embed_text(text_data)
|
||||
text_cond = dict(
|
||||
text_embed=text_embedding, text_encodings=text_encodings
|
||||
)
|
||||
text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
|
||||
else:
|
||||
text_embedding = text_data
|
||||
text_cond = dict(text_embed=text_embedding)
|
||||
@@ -150,8 +294,7 @@ def report_cosine_sims(
|
||||
text_encodings_shuffled = None
|
||||
|
||||
text_cond_shuffled = dict(
|
||||
text_embed=text_embed_shuffled,
|
||||
text_encodings=text_encodings_shuffled
|
||||
text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
|
||||
)
|
||||
|
||||
# prepare the text embedding
|
||||
@@ -164,7 +307,9 @@ def report_cosine_sims(
|
||||
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond
|
||||
test_image_embeddings.shape,
|
||||
text_cond,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
predicted_image_embeddings = (
|
||||
@@ -174,7 +319,9 @@ def report_cosine_sims(
|
||||
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = trainer.p_sample_loop(
|
||||
test_image_embeddings.shape, text_cond_shuffled
|
||||
test_image_embeddings.shape,
|
||||
text_cond_shuffled,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
predicted_unrelated_embeddings = (
|
||||
@@ -183,215 +330,425 @@ def report_cosine_sims(
|
||||
)
|
||||
|
||||
# calculate similarities
|
||||
original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = (
|
||||
cos(text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
orig_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, test_image_embeddings), method="mean"
|
||||
)
|
||||
predicted_img_similarity = (
|
||||
cos(test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
pred_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, predicted_image_embeddings), method="mean"
|
||||
)
|
||||
unrel_sim = pad_gather_reduce(
|
||||
trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
|
||||
)
|
||||
pred_img_sim = pad_gather_reduce(
|
||||
trainer,
|
||||
cos(test_image_embeddings, predicted_image_embeddings),
|
||||
method="mean",
|
||||
)
|
||||
|
||||
stats = {
|
||||
f"{tracker_context}/baseline similarity": np.mean(original_similarity),
|
||||
f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
|
||||
f"{tracker_context}/similarity with original image": np.mean(
|
||||
predicted_img_similarity
|
||||
),
|
||||
f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
|
||||
f"{tracker_context}/difference from baseline similarity": np.mean(
|
||||
predicted_similarity - original_similarity
|
||||
),
|
||||
f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
|
||||
f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
|
||||
f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
|
||||
f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
|
||||
f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
|
||||
- orig_sim,
|
||||
}
|
||||
|
||||
for k, v in stats.items():
|
||||
trainer.print(f"{tracker_context}/{k}: {v}")
|
||||
|
||||
if exists(tracker):
|
||||
tracker.log(stats, step=trainer.step.item() + 1)
|
||||
|
||||
|
||||
def eval_model(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
dataloader: DataLoader,
|
||||
text_conditioned: bool,
|
||||
split: str,
|
||||
tracker: Tracker,
|
||||
use_ema: bool,
|
||||
report_cosine: bool,
|
||||
report_loss: bool,
|
||||
timesteps: List[int],
|
||||
loss_type: str = None,
|
||||
):
|
||||
"""
|
||||
Run evaluation on a model and track metrics
|
||||
|
||||
returns: loss if requested
|
||||
"""
|
||||
trainer.eval()
|
||||
|
||||
use_ema = "ema" if use_ema else "online"
|
||||
tracker_folder = f"metrics/{use_ema}-{split}"
|
||||
|
||||
# detemine if valid timesteps are passed
|
||||
|
||||
min_timesteps = trainer.accelerator.unwrap_model(
|
||||
trainer.diffusion_prior
|
||||
).sample_timesteps
|
||||
max_timesteps = trainer.accelerator.unwrap_model(
|
||||
trainer.diffusion_prior
|
||||
).noise_scheduler.num_timesteps
|
||||
|
||||
assert all_between(
|
||||
timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
|
||||
), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
|
||||
|
||||
# measure cosine metrics across various eta and timesteps
|
||||
|
||||
if report_cosine:
|
||||
for timestep in timesteps:
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
dataloader=dataloader,
|
||||
text_conditioned=text_conditioned,
|
||||
tracker=tracker,
|
||||
split=split,
|
||||
timesteps=timestep,
|
||||
tracker_folder=tracker_folder,
|
||||
)
|
||||
|
||||
# measure loss on a seperate split of data
|
||||
|
||||
if report_loss:
|
||||
loss = report_validation_loss(
|
||||
trainer=trainer,
|
||||
dataloader=dataloader,
|
||||
text_conditioned=text_conditioned,
|
||||
use_ema=use_ema,
|
||||
tracker=tracker,
|
||||
split=split,
|
||||
tracker_folder=tracker_folder,
|
||||
loss_type=loss_type,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# training script
|
||||
|
||||
|
||||
def train(
|
||||
trainer: DiffusionPriorTrainer,
|
||||
tracker: Tracker,
|
||||
train_loader: DataLoader,
|
||||
eval_loader: DataLoader,
|
||||
test_loader: DataLoader,
|
||||
config: DiffusionPriorTrainConfig,
|
||||
):
|
||||
# distributed tracking with wandb
|
||||
if trainer.accelerator.num_processes > 1:
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
# init timers
|
||||
save_timer = Timer() # when to save
|
||||
samples_timer = Timer() # samples/sec
|
||||
validation_profiler = Timer() # how long is validation taking
|
||||
validation_countdown = Timer() # when to perform evalutation
|
||||
|
||||
tracker = wandb.init(
|
||||
name=f"RANK:{trainer.device}",
|
||||
entity=config.tracker.wandb_entity,
|
||||
project=config.tracker.wandb_project,
|
||||
config=config.dict(),
|
||||
group=GROUP,
|
||||
)
|
||||
# keep track of best validation loss
|
||||
|
||||
# sync after tracker init
|
||||
trainer.wait_for_everyone()
|
||||
|
||||
# init a timer
|
||||
timer = Timer()
|
||||
best_validation_loss = config.train.best_validation_loss
|
||||
samples_seen = config.train.num_samples_seen
|
||||
|
||||
# do training
|
||||
|
||||
start_epoch = config.train.current_epoch
|
||||
|
||||
for epoch in range(start_epoch, config.train.epochs):
|
||||
# if we finished out an old epoch, reset the distribution to be a full epoch
|
||||
tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
|
||||
|
||||
if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Finished resumed epoch...resetting dataloader.")
|
||||
train_loader.dataset.set_start(0)
|
||||
|
||||
for img, txt in train_loader:
|
||||
# setup things every step
|
||||
|
||||
trainer.train()
|
||||
current_step = trainer.step.item() + 1
|
||||
current_step = trainer.step.item()
|
||||
samples_timer.reset()
|
||||
|
||||
# place data on device
|
||||
|
||||
img = img.to(trainer.device)
|
||||
txt = txt.to(trainer.device)
|
||||
|
||||
# pass to model
|
||||
|
||||
loss = trainer(text=txt, image_embed=img)
|
||||
|
||||
# display & log loss (will only print from main process)
|
||||
trainer.print(f"Step {current_step}: Loss {loss}")
|
||||
|
||||
# perform backprop & apply EMA updates
|
||||
|
||||
trainer.update()
|
||||
|
||||
# track samples/sec/rank
|
||||
samples_per_sec = img.shape[0] / timer.elapsed()
|
||||
# gather info about training step
|
||||
|
||||
# samples seen
|
||||
samples_seen = (
|
||||
config.data.batch_size * trainer.accelerator.num_processes * current_step
|
||||
)
|
||||
|
||||
# ema decay
|
||||
all_loss = pad_gather_reduce(trainer, loss, method="mean")
|
||||
num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
|
||||
samples_per_sec = num_samples / samples_timer.elapsed()
|
||||
samples_seen += num_samples
|
||||
ema_decay = trainer.ema_diffusion_prior.get_current_decay()
|
||||
|
||||
# Log on all processes for debugging
|
||||
# log
|
||||
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/samples-sec": samples_per_sec,
|
||||
"tracking/samples-seen": samples_seen,
|
||||
"tracking/ema-decay": ema_decay,
|
||||
"metrics/training-loss": loss,
|
||||
f"tracking/training-{config.prior.loss_type}": all_loss,
|
||||
},
|
||||
step=current_step,
|
||||
)
|
||||
|
||||
# Metric Tracking & Checkpointing (outside of timer's scope)
|
||||
if current_step % EVAL_EVERY == 0:
|
||||
# Metric Tracking @ Timed Intervals
|
||||
|
||||
eval_delta = pad_gather_reduce(
|
||||
trainer, validation_countdown.elapsed(), method="min"
|
||||
)
|
||||
|
||||
if eval_delta != None and eval_delta > config.data.eval_every_seconds:
|
||||
# begin timing how long this takes
|
||||
|
||||
validation_profiler.reset()
|
||||
|
||||
# package kwargs for evaluation
|
||||
|
||||
eval_kwargs = {
|
||||
"trainer": trainer,
|
||||
"tracker": tracker,
|
||||
"text_conditioned": config.prior.condition_on_text_encodings,
|
||||
"timesteps": config.train.eval_timesteps,
|
||||
}
|
||||
|
||||
# ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/online-model-validation",
|
||||
tracker=tracker,
|
||||
split="validation",
|
||||
use_ema=False,
|
||||
report_cosine=False,
|
||||
report_loss=True,
|
||||
**eval_kwargs,
|
||||
)
|
||||
|
||||
eval_model(
|
||||
trainer=trainer,
|
||||
# EMA MODEL : COSINE : LOSS : VALIDATION DATA
|
||||
|
||||
ema_val_loss = eval_model(
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="metrics/ema-model-validation",
|
||||
tracker=tracker,
|
||||
split="validation",
|
||||
use_ema=True,
|
||||
report_cosine=True,
|
||||
report_loss=True,
|
||||
**eval_kwargs,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
tracker.log(
|
||||
{
|
||||
"tracking/validation length (minutes)": validation_profiler.elapsed()
|
||||
/ 60
|
||||
}
|
||||
)
|
||||
|
||||
# check if the ema validation is the lowest seen yet
|
||||
|
||||
if ema_val_loss < best_validation_loss:
|
||||
best_validation_loss = ema_val_loss
|
||||
|
||||
# go save the model as best
|
||||
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
dataloader=eval_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
tracker=tracker,
|
||||
tracker_context="metrics",
|
||||
is_best=True,
|
||||
is_latest=False,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
if current_step % config.train.save_every == 0:
|
||||
trainer.save(f"{config.tracker.data_path}/chkpt_step_{current_step}.pth")
|
||||
# reset timer for validaiton
|
||||
|
||||
# reset timer for next round
|
||||
timer.reset()
|
||||
validation_countdown.reset()
|
||||
|
||||
elif eval_delta is None:
|
||||
click.secho(
|
||||
f"Error occured reading the eval time on rank: {trainer.device}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
# save as latest model on schedule
|
||||
|
||||
save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
|
||||
|
||||
if save_delta != None and save_delta >= config.train.save_every_seconds:
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
is_best=False,
|
||||
is_latest=True,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
save_timer.reset()
|
||||
|
||||
elif save_delta is None:
|
||||
click.secho(
|
||||
f"Error occured reading the save time on rank: {trainer.device}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
# evaluate on test data
|
||||
|
||||
eval_model(
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Starting Test", fg="red")
|
||||
|
||||
# save one last time as latest before beginning validation
|
||||
|
||||
save_trainer(
|
||||
tracker=tracker,
|
||||
trainer=trainer,
|
||||
is_best=False,
|
||||
is_latest=True,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=best_validation_loss,
|
||||
)
|
||||
|
||||
test_loss = eval_model(
|
||||
trainer=trainer,
|
||||
dataloader=test_loader,
|
||||
text_conditioned=config.prior.condition_on_text_encodings,
|
||||
loss_type=config.prior.loss_type,
|
||||
tracker_context="test",
|
||||
split="test",
|
||||
tracker=tracker,
|
||||
use_ema=True,
|
||||
report_cosine=False,
|
||||
report_loss=True,
|
||||
timesteps=config.train.eval_timesteps,
|
||||
loss_type=config.prior.loss_type,
|
||||
)
|
||||
|
||||
report_cosine_sims(
|
||||
trainer,
|
||||
test_loader,
|
||||
config.prior.condition_on_text_encodings,
|
||||
tracker,
|
||||
tracker_context="test",
|
||||
if test_loss < best_validation_loss:
|
||||
best_validation_loss = test_loss
|
||||
|
||||
# go save the model as best
|
||||
|
||||
save_trainer(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
is_best=True,
|
||||
is_latest=False,
|
||||
samples_seen=samples_seen,
|
||||
epoch=epoch,
|
||||
best_validation_loss=test_loss,
|
||||
)
|
||||
|
||||
|
||||
def initialize_training(config, accelerator=None):
|
||||
def initialize_training(config_file, accelerator):
|
||||
"""
|
||||
Parse the configuration file, and prepare everything necessary for training
|
||||
"""
|
||||
# load the configuration file
|
||||
if accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_file}", fg="green")
|
||||
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_file)
|
||||
|
||||
# seed
|
||||
|
||||
set_seed(config.train.random_seed)
|
||||
|
||||
# get a device
|
||||
|
||||
if accelerator:
|
||||
device = accelerator.device
|
||||
click.secho(f"Accelerating on: {device}", fg="yellow")
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
click.secho("GPU detected, defaulting to cuda:0", fg="yellow")
|
||||
device = "cuda:0"
|
||||
else:
|
||||
click.secho("No GPU detected...using cpu", fg="yellow")
|
||||
device = "cpu"
|
||||
|
||||
# make the trainer (will automatically distribute if possible & configured)
|
||||
|
||||
trainer = make_model(config.prior, config.train, device, accelerator).to(device)
|
||||
trainer: DiffusionPriorTrainer = make_model(
|
||||
config.prior, config.train, device, accelerator
|
||||
).to(device)
|
||||
|
||||
# create a tracker
|
||||
|
||||
tracker = create_tracker(
|
||||
accelerator, config, config_file, dummy=accelerator.process_index != 0
|
||||
)
|
||||
|
||||
# reload from chcekpoint
|
||||
|
||||
if config.load.resume == True:
|
||||
click.secho(f"Loading checkpoint: {config.load.source}", fg="cyan")
|
||||
trainer.load(config.load.source)
|
||||
if tracker.can_recall:
|
||||
current_epoch, best_validation_loss, samples_seen = recall_trainer(
|
||||
tracker=tracker, trainer=trainer
|
||||
)
|
||||
|
||||
# display best values
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
|
||||
|
||||
# update config to reflect recalled values
|
||||
config.train.num_samples_seen = samples_seen
|
||||
config.train.current_epoch = current_epoch
|
||||
config.train.best_validation_loss = best_validation_loss
|
||||
|
||||
# fetch and prepare data
|
||||
|
||||
if trainer.is_main_process():
|
||||
click.secho("Grabbing data from source", fg="blue", blink=True)
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho("Grabbing data...", fg="blue", blink=True)
|
||||
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
img_reader = get_reader(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
img_url=config.data.image_url,
|
||||
meta_url=config.data.meta_url,
|
||||
)
|
||||
|
||||
# calculate start point within epoch
|
||||
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
|
||||
train_loader, eval_loader, test_loader = make_splits(
|
||||
text_conditioned=trainer.text_conditioned,
|
||||
batch_size=config.data.batch_size,
|
||||
num_data_points=NUM_DATA_POINTS,
|
||||
num_data_points=config.data.num_data_points,
|
||||
train_split=config.data.splits.train,
|
||||
eval_split=config.data.splits.val,
|
||||
image_reader=img_reader,
|
||||
rank=accelerator.state.process_index if exists(accelerator) else 0,
|
||||
world_size=accelerator.state.num_processes if exists(accelerator) else 1,
|
||||
start=START,
|
||||
rank=accelerator.state.process_index,
|
||||
world_size=accelerator.state.num_processes,
|
||||
start=0,
|
||||
)
|
||||
|
||||
# wait for everyone to load data before continuing
|
||||
trainer.wait_for_everyone()
|
||||
# update the start point to finish out the epoch on a resumed run
|
||||
|
||||
if tracker.can_recall:
|
||||
samples_seen = config.train.num_samples_seen
|
||||
length = (
|
||||
config.data.num_data_points
|
||||
if samples_seen <= img_reader.count
|
||||
else img_reader.count
|
||||
)
|
||||
scaled_samples = length * config.train.current_epoch
|
||||
start_point = (
|
||||
scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
|
||||
)
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(f"Resuming at sample: {start_point}", fg="yellow")
|
||||
|
||||
train_loader.dataset.set_start(start_point)
|
||||
|
||||
# start training
|
||||
|
||||
if trainer.accelerator.is_main_process:
|
||||
click.secho(
|
||||
f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
train(
|
||||
trainer=trainer,
|
||||
tracker=tracker,
|
||||
train_loader=train_loader,
|
||||
eval_loader=eval_loader,
|
||||
test_loader=test_loader,
|
||||
@@ -400,23 +757,13 @@ def initialize_training(config, accelerator=None):
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--hfa", default=True)
|
||||
@click.option("--config_path", default="configs/prior.json")
|
||||
def main(hfa, config_path):
|
||||
# start HFA if requested
|
||||
if hfa:
|
||||
@click.option("--config_file", default="configs/train_prior_config.example.json")
|
||||
def main(config_file):
|
||||
# start HFA
|
||||
accelerator = Accelerator()
|
||||
else:
|
||||
accelerator = None
|
||||
|
||||
# load the configuration file on main process
|
||||
if not exists(accelerator) or accelerator.is_main_process:
|
||||
click.secho(f"Loading configuration from {config_path}", fg="green")
|
||||
|
||||
config = TrainDiffusionPriorConfig.from_json_path(config_path)
|
||||
|
||||
# send config to get processed
|
||||
initialize_training(config, accelerator)
|
||||
# setup training
|
||||
initialize_training(config_file, accelerator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user