from pathlib import Path
from typing import List

from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize

import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click

# constants

TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10

# helper functions

def exists(val):
    return val is not None

# main functions

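# create_dataloaders splits the shard list into train/val/test and builds one webdataset
# loader per split. webdataset_base_url is expected to be a format string with a single `{}`
# placeholder for the zero-padded shard number; for example (hypothetical path), with
# shard_width=6 a base url of "s3://bucket/dataset/{}.tar" turns shard 123 into
# "s3://bucket/dataset/000123.tar".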
def create_dataloaders(
    available_shards,
    webdataset_base_url,
    img_embeddings_url=None,
    text_embeddings_url=None,
    shard_width=6,
    num_workers=4,
    batch_size=32,
    n_sample_images=6,
    shuffle_train=True,
    resample_train=False,
    img_preproc=None,
    index_width=4,
    train_prop=0.75,
    val_prop=0.15,
    test_prop=0.10,
    seed=0,
    **kwargs
):
    """
    Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
    """
    assert train_prop + test_prop + val_prop == 1
    num_train = round(train_prop*len(available_shards))
    num_test = round(test_prop*len(available_shards))
    num_val = len(available_shards) - num_train - num_test
    assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))

    # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
    train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
    test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
    val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

    create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
        tar_url=tar_urls,
        num_workers=num_workers,
        batch_size=batch_size if not for_sampling else n_sample_images,
        img_embeddings_url=img_embeddings_url,
        text_embeddings_url=text_embeddings_url,
        index_width=index_width,
        shuffle_num=None,
        extra_keys=["txt"],
        shuffle_shards=shuffle,
        resample_shards=resample,
        img_preproc=img_preproc,
        handler=wds.handlers.warn_and_continue
    )

    train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
    train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
    val_dataloader = create_dataloader(val_urls, shuffle=False)
    test_dataloader = create_dataloader(test_urls, shuffle=False)
    test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
    return {
        "train": train_dataloader,
        "train_sampling": train_sampling_dataloader,
        "val": val_dataloader,
        "test": test_dataloader,
        "test_sampling": test_sampling_dataloader
    }

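# When webdataset wraps the dataset in a wds.WebLoader, the underlying dataset sits at the
# first stage of its pipeline; the helper below digs it out to read the key_map.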
def get_dataset_keys(dataloader):
    """
    It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to do a bit of digging to recover it.
    """
    # If the dataloader is actually a WebLoader, we need to extract the real dataloader
    if isinstance(dataloader, wds.WebLoader):
        dataloader = dataloader.pipeline[0]
    return dataloader.dataset.key_map

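# The dataloaders yield (img, emb, txt) triples where emb is a dict that may contain
# precomputed 'img' and/or 'text' embeddings. get_example_data collects up to n examples and
# substitutes None for any embedding that was not provided, so downstream code can decide
# whether to regenerate it.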
def get_example_data(dataloader, device, n=5):
    """
    Samples the dataloader and returns a zipped list of examples
    """
    images = []
    img_embeddings = []
    text_embeddings = []
    captions = []
    for img, emb, txt in dataloader:
        img_emb, text_emb = emb.get('img'), emb.get('text')
        if img_emb is not None:
            img_emb = img_emb.to(device=device, dtype=torch.float)
            img_embeddings.extend(list(img_emb))
        else:
            # Then we add None img.shape[0] times
            img_embeddings.extend([None]*img.shape[0])
        if text_emb is not None:
            text_emb = text_emb.to(device=device, dtype=torch.float)
            text_embeddings.extend(list(text_emb))
        else:
            # Then we add None img.shape[0] times
            text_embeddings.extend([None]*img.shape[0])
        img = img.to(device=device, dtype=torch.float)
        images.extend(list(img))
        captions.extend(list(txt))
        if len(images) >= n:
            break
    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

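# generate_samples consumes the (image, img_embedding, text_embedding, caption) tuples
# produced by get_example_data. When an embedding is None (no precomputed embedding files
# were configured), the trainer's clip model embeds the images on the fly, and raw captions
# are tokenized and passed as 'text' rather than as 'text_encodings'.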
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
    """
    Takes example data and generates images from the embeddings
    Returns three lists: real images, generated images, and captions
    """
    real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
    sample_params = {}
    if img_embeddings[0] is None:
        # Generate image embeddings from clip
        imgs_tensor = torch.stack(real_images)
        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
        sample_params["image_embed"] = img_embeddings
    else:
        # Then we are using precomputed image embeddings
        img_embeddings = torch.stack(img_embeddings)
        sample_params["image_embed"] = img_embeddings
    if condition_on_text_encodings:
        if text_embeddings[0] is None:
            # Generate text embeddings from text
            tokenized_texts = tokenize(txts, truncate=True)
            sample_params["text"] = tokenized_texts
        else:
            # Then we are using precomputed text embeddings
            text_embeddings = torch.stack(text_embeddings)
            sample_params["text_encodings"] = text_embeddings
    samples = trainer.sample(**sample_params)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
    if match_image_size:
        generated_image_size = generated_images[0].shape[-1]
        real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
    return real_images, generated_images, captions

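# Each grid pairs an original image with its generated counterpart via
# torchvision.utils.make_grid, so a single logged image shows the side-by-side comparison.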
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
    """
    Generates samples and uses torchvision to put them in a side-by-side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

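# evaluate_trainer computes whichever of FID / IS / KID / LPIPS are configured. Each of the
# FID, IS, KID, LPIPS arguments is expected to be a dict of keyword arguments forwarded to the
# corresponding torchmetrics class; passing None skips that metric. The metrics are built with
# a no-op dist_sync_fn and then averaged across processes manually with accelerator.gather.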
def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Computes evaluation metrics for the decoder
    """
    metrics = {}
    # Prepare the data
    examples = get_example_data(dataloader, device, n_evaluation_samples)
    if len(examples) == 0:
        print("No data to evaluate. Check that your dataloader has shards.")
        return metrics
    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
    int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)

    def null_sync(t, *args, **kwargs):
        return [t]

    if exists(FID):
        fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
        fid.to(device=device)
        fid.update(int_real_images, real=True)
        fid.update(int_generated_images, real=False)
        metrics["FID"] = fid.compute().item()
    if exists(IS):
        inception = InceptionScore(**IS, dist_sync_fn=null_sync)
        inception.to(device=device)
        inception.update(int_real_images)
        is_mean, is_std = inception.compute()
        metrics["IS_mean"] = is_mean.item()
        metrics["IS_std"] = is_std.item()
    if exists(KID):
        kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
        kernel_inception.to(device=device)
        kernel_inception.update(int_real_images, real=True)
        kernel_inception.update(int_generated_images, real=False)
        kid_mean, kid_std = kernel_inception.compute()
        metrics["KID_mean"] = kid_mean.item()
        metrics["KID_std"] = kid_std.item()
    if exists(LPIPS):
        # Convert from [0, 1] to [-1, 1]
        renorm_real_images = real_images.mul(2).sub(1)
        renorm_generated_images = generated_images.mul(2).sub(1)
        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
        lpips.to(device=device)
        lpips.update(renorm_real_images, renorm_generated_images)
        metrics["LPIPS"] = lpips.compute().item()

    if trainer.accelerator.num_processes > 1:
        # Then we should sync the metrics
        metrics_order = sorted(metrics.keys())
        metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
        for i, metric_name in enumerate(metrics_order):
            metrics_tensor[0, i] = metrics[metric_name]
        metrics_tensor = trainer.accelerator.gather(metrics_tensor)
        metrics_tensor = metrics_tensor.mean(dim=0)
        for i, metric_name in enumerate(metrics_order):
            metrics[metric_name] = metrics_tensor[i].item()
    return metrics

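# Checkpoint helpers: the tracker stores the trainer state together with the bookkeeping
# needed to resume a run (epoch, sample, next_task, validation_losses, samples_seen), and
# recall_trainer reads those same keys back out of the recalled state dict.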
def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
    """
    Saves the model with an appropriate method depending on the tracker
    """
    tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
    """
    Loads the model with an appropriate method depending on the tracker
    """
    trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
    state_dict = tracker.recall()
    trainer.load_state_dict(state_dict, only_model=False, strict=True)
    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)

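# train runs a simple per-epoch state machine driven by next_task:
# 'train' -> 'val' -> 'eval' -> 'sample' -> back to 'train'. The current task is saved with
# every checkpoint, so a recalled run resumes in whichever phase it was interrupted in.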
def train(
    dataloaders,
    decoder: Decoder,
    accelerator: Accelerator,
    tracker: Tracker,
    inference_device,
    evaluate_config=None,
    epoch_samples=None,  # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples=None,
    epochs=20,
    n_sample_images=5,
    save_every_n_samples=100000,
    unet_training_mask=None,
    condition_on_text_encodings=False,
    **kwargs
):
    """
    Trains a decoder on a dataset.
    """
    is_master = accelerator.process_index == 0

    trainer = DecoderTrainer(
        decoder=decoder,
        accelerator=accelerator,
        dataloaders=dataloaders,
        **kwargs
    )

    # Set up starting model and parameters based on a recalled state dict
    start_epoch = 0
    validation_losses = []
    next_task = 'train'
    sample = 0
    samples_seen = 0
    val_sample = 0

    if tracker.can_recall:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
        if next_task == 'train':
            sample = recalled_sample
        if next_task == 'val':
            val_sample = recalled_sample
        accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
        accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
    trainer.to(device=inference_device)

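    # unet_training_mask selects which unets receive updates; when it is not given, every unet
    # in the decoder is trained. The logging step below is read from the first trained unet's
    # step counter so all logged quantities share one step axis.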
    if not exists(unet_training_mask):
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

    accelerator.print(print_ribbon("Generating Example Data", repeat=40))
    accelerator.print("This can take a while to load the shard lists...")
    if is_master:
        train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
        accelerator.print("Generated training examples")
        test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
        accelerator.print("Generated testing examples")

    send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]

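    # Per-iteration bookkeeping tensors: sample_length_tensor holds this process's batch size
    # so the global sample count can be gathered across processes, and unet_losses_tensor
    # buffers the last TRAIN_CALC_LOSS_EVERY_ITERS losses per unet for cross-process averaging
    # when logging.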
    sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
    unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
    for epoch in range(start_epoch, epochs):
        accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))

        timer = Timer()
        last_sample = sample
        last_snapshot = sample

        if next_task == 'train':
            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
                # We want to count the total number of samples across all processes
                sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
                total_samples = all_samples.sum().item()
                sample += total_samples
                samples_seen += total_samples
                img_emb = emb.get('img')
                has_img_embedding = img_emb is not None
                if has_img_embedding:
                    img_emb, = send_to_device((img_emb,))
                text_emb = emb.get('text')
                has_text_embedding = text_emb is not None
                if has_text_embedding:
                    text_emb, = send_to_device((text_emb,))
                img, = send_to_device((img,))

                trainer.train()
                for unet in range(1, trainer.num_unets+1):
                    # Check if this is a unet we are training
                    if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                        continue

                    forward_params = {}
                    if has_img_embedding:
                        forward_params['image_embed'] = img_emb
                    else:
                        # Forward pass automatically generates embedding
                        pass
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb
                        else:
                            # Then we need to pass the text instead
                            tokenized_texts = tokenize(txt, truncate=True)
                            assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                            forward_params['text'] = tokenized_texts
                    loss = trainer.forward(img, **forward_params, unet_number=unet)
                    trainer.update(unet_number=unet)
                    unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss

                samples_per_sec = (sample - last_sample) / timer.elapsed()
                timer.reset()
                last_sample = sample

                if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                    # We want to average losses across all processes
                    unet_all_losses = accelerator.gather(unet_losses_tensor)
                    mask = unet_all_losses != 0
                    unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
                    loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }

                    # gather decay rate on each UNet
                    ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}

                    log_data = {
                        "Epoch": epoch,
                        "Sample": sample,
                        "Step": i,
                        "Samples per second": samples_per_sec,
                        "Samples Seen": samples_seen,
                        **ema_decay_list,
                        **loss_map
                    }

                    if is_master:
                        tracker.log(log_data, step=step())

                if is_master and last_snapshot + save_every_n_samples < sample:  # This will miss by some amount every time, but it's not a big deal... I hope
                    # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                    print("Saving snapshot")
                    last_snapshot = sample
                    # We need to know where the model should be saved
                    save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                    if exists(n_sample_images) and n_sample_images > 0:
                        trainer.eval()
                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                if epoch_samples is not None and sample >= epoch_samples:
                    break
            next_task = 'val'
            sample = 0

        all_average_val_losses = None
        if next_task == 'val':
            trainer.eval()
            accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
            last_val_sample = val_sample
            val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
            average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
            timer = Timer()
            accelerator.wait_for_everyone()
            i = 0
            for i, (img, emb, txt) in enumerate(dataloaders['val']):  # Use the accelerate prepared loader
                val_sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(val_sample_length_tensor)
                total_samples = all_samples.sum().item()
                val_sample += total_samples
                img_emb = emb.get('img')
                has_img_embedding = img_emb is not None
                if has_img_embedding:
                    img_emb, = send_to_device((img_emb,))
                text_emb = emb.get('text')
                has_text_embedding = text_emb is not None
                if has_text_embedding:
                    text_emb, = send_to_device((text_emb,))
                img, = send_to_device((img,))

                for unet in range(1, len(decoder.unets)+1):
                    if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                        # No need to evaluate an unchanging unet
                        continue

                    forward_params = {}
                    if has_img_embedding:
                        forward_params['image_embed'] = img_emb.float()
                    else:
                        # Forward pass automatically generates embedding
                        pass
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb.float()
                        else:
                            # Then we need to pass the text instead
                            tokenized_texts = tokenize(txt, truncate=True)
                            forward_params['text'] = tokenized_texts
                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
                    average_val_loss_tensor[0, unet-1] += loss

                if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
                    samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
                    timer.reset()
                    last_val_sample = val_sample
                    accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
                    accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
                    accelerator.print("")

                if validation_samples is not None and val_sample >= validation_samples:
                    break
            print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
            accelerator.wait_for_everyone()
            average_val_loss_tensor /= i+1
            # Gather all the average loss tensors
            all_average_val_losses = accelerator.gather(average_val_loss_tensor)
            if is_master:
                unet_average_val_loss = all_average_val_losses.mean(dim=0)
                val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
                tracker.log(val_loss_map, step=step())
            next_task = 'eval'

        if next_task == 'eval':
            if exists(evaluate_config):
                accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
                if is_master:
                    tracker.log(evaluation, step=step())
            next_task = 'sample'
            val_sample = 0

        if next_task == 'sample':
            if is_master:
                # Generate examples and save the model if we are the master
                # Generate sample images
                print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
                test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
                train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
                tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
                is_best = False
                if all_average_val_losses is not None:
                    average_loss = all_average_val_losses.mean(dim=0).item()
                    if len(validation_losses) == 0 or average_loss < min(validation_losses):
                        is_best = True
                    validation_losses.append(average_loss)
                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
            next_task = 'train'

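# create_tracker records the accelerate environment alongside the run, saves the training
# config next to it as decoder_config.json, and attaches the config to every checkpoint under
# the 'config' key so models can be reloaded with the settings that produced them.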
def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
    tracker_config = config.tracker
    accelerator_config = {
        "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
        "DistributedType": accelerator.distributed_type,
        "NumProcesses": accelerator.num_processes,
        "MixedPrecision": accelerator.mixed_precision
    }
    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
    tracker.save_config(config_path, config_name='decoder_config.json')
    tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
    return tracker

def initialize_training(config: TrainDecoderConfig, config_path):
    # Make sure if we are not loading, distributed models are initialized to the same values
    torch.manual_seed(config.seed)

    # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    if accelerator.num_processes > 1:
        # We are using distributed training and want to immediately ensure all can connect
        accelerator.print("Waiting for all processes to connect...")
        accelerator.wait_for_everyone()
        accelerator.print("All processes online and connected")

    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
        raise ValueError("DeepSpeed fp16 mode does not support learned variance")

    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
        # This is an invalid configuration until we figure out how to handle this
        raise ValueError("DeepSpeed does not support multi-node distributed training")

    # Set up data
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
    world_size = accelerator.num_processes
    rank = accelerator.process_index
    shards_per_process = len(all_shards) // world_size
    assert shards_per_process > 0, "Not enough shards to split evenly"
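    # Each process takes a contiguous slice of the shard range; if the shard count is not
    # divisible by the world size, the leftover shards are simply not used by any process.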
    my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
    dataloaders = create_dataloaders(
        available_shards=my_shards,
        img_preproc=config.data.img_preproc,
        train_prop=config.data.splits.train,
        val_prop=config.data.splits.val,
        test_prop=config.data.splits.test,
        n_sample_images=config.train.n_sample_images,
        **config.data.dict(),
        rank=rank,
        seed=config.seed,
    )

    # Create the decoder model and print basic info
    decoder = config.decoder.create()
    get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))

    # Create and initialize the tracker if we are the master
    tracker = create_tracker(accelerator, config, config_path, dummy=rank != 0)

    has_img_embeddings = config.data.img_embeddings_url is not None
    has_text_embeddings = config.data.text_embeddings_url is not None
    conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

    has_clip_model = config.decoder.clip is not None
    data_source_string = ""

    if has_img_embeddings:
        data_source_string += "precomputed image embeddings"
    elif has_clip_model:
        data_source_string += "clip image embeddings generation"
    else:
        raise ValueError("No image embeddings source specified")
    if conditioning_on_text:
        if has_text_embeddings:
            data_source_string += " and precomputed text embeddings"
        elif has_clip_model:
            data_source_string += " and clip text encoding generation"
        else:
            raise ValueError("No text embeddings source specified")

    accelerator.print(print_ribbon("Loaded Config", repeat=40))
    accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
    accelerator.print(f"Training using {data_source_string}. {'Conditioned on text' if conditioning_on_text else 'Not conditioned on text'}.")
    accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
    for i, unet in enumerate(decoder.unets):
        accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

    train(dataloaders, decoder, accelerator,
        tracker=tracker,
        inference_device=accelerator.device,
        evaluate_config=config.evaluate,
        condition_on_text_encodings=conditioning_on_text,
        **config.train.dict(),
    )

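# The config file is a JSON document parsed into a TrainDecoderConfig. A rough skeleton of the
# sections this script reads (the exact fields are defined in dalle2_pytorch.train_configs;
# this is only a sketch, not the authoritative schema):
#
#   {
#       "seed": 0,
#       "decoder": { ... unets, clip, learned_variance, ... },
#       "data": { ... start_shard, end_shard, webdataset_base_url, splits, ... },
#       "train": { ... epochs, n_sample_images, find_unused_parameters, ... },
#       "evaluate": { ... FID / IS / KID / LPIPS kwargs ... },
#       "tracker": { ... logging and saving configuration ... }
#   }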
# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
    config_file_path = Path(config_file)
    config = TrainDecoderConfig.from_json_path(str(config_file_path))
    initialize_training(config, config_path=config_file_path)

if __name__ == "__main__":
    main()