import math
from pathlib import Path
from typing import List

from dalle2_pytorch.trainer import DecoderTrainer
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import Tracker
from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
from clip import tokenize

import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import dataclasses as accelerate_dataclasses
import webdataset as wds
import click

# constants

# How often (in iterations) training losses are gathered across processes and logged. This is also
# the number of rows in the rolling per-iteration loss buffer those averages are computed from.
TRAIN_CALC_LOSS_EVERY_ITERS = 10
# How often (in iterations) validation throughput and the running validation loss are printed.
VALID_CALC_LOSS_EVERY_ITERS = 10

# helpers functions

def exists(val):
    """Return True if *val* is not None."""
    return val is not None

# main functions

def create_dataloaders(
    available_shards,
    webdataset_base_url,
    img_embeddings_url=None,
    text_embeddings_url=None,
    shard_width=6,
    num_workers=4,
    batch_size=32,
    n_sample_images=6,
    shuffle_train=True,
    resample_train=False,
    img_preproc = None,
    index_width=4,
    train_prop = 0.75,
    val_prop = 0.15,
    test_prop = 0.10,
    seed = 0,
    **kwargs
):
    """
    Randomly split the available shards into train, val, and test sets and return a dataloader for each.

    The split is reproducible for a given ``seed``. Each shard number is zero padded to ``shard_width``
    digits and substituted into ``webdataset_base_url`` with ``str.format`` to build the tar URL.

    Args:
        available_shards: Shard numbers this process may use.
        webdataset_base_url: Format string with one ``{}`` placeholder for the padded shard number.
        img_embeddings_url / text_embeddings_url: Optional sources of precomputed embeddings.
        n_sample_images: Batch size of the ``*_sampling`` loaders.
        shuffle_train / resample_train: Shard shuffling / resampling for the main train loader only.
        train_prop, val_prop, test_prop: Split proportions; must sum to 1 (within float tolerance).
        seed: Seed for the random shard split.
        **kwargs: Ignored; lets callers pass a whole data config dict.

    Returns:
        dict with keys "train", "train_sampling", "val", "test", "test_sampling".

    Raises:
        AssertionError: if the proportions do not sum to 1.
    """
    # Exact float equality rejects valid configs (e.g. 0.7 + 0.1 + 0.2 == 0.9999999999999999)
    assert math.isclose(train_prop + test_prop + val_prop, 1.0), \
        f"Split proportions must sum to 1, got train={train_prop}, val={val_prop}, test={test_prop}"
    num_shards = len(available_shards)
    num_train = round(train_prop * num_shards)
    # Rounding both train and test up could otherwise leave a negative validation count
    num_test = min(round(test_prop * num_shards), num_shards - num_train)
    num_val = num_shards - num_train - num_test
    assert num_train + num_test + num_val == num_shards, f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {num_shards}"
    train_split, test_split, val_split = torch.utils.data.random_split(
        available_shards,
        [num_train, num_test, num_val],
        generator=torch.Generator().manual_seed(seed)
    )

    def to_urls(shards):
        # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
        return [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in shards]

    train_urls = to_urls(train_split)
    test_urls = to_urls(test_split)
    val_urls = to_urls(val_split)

    def create_dataloader(tar_urls, shuffle=False, resample=False, for_sampling=False):
        # Sampling loaders only need a handful of examples, so they use n_sample_images as the batch size
        return create_image_embedding_dataloader(
            tar_url=tar_urls,
            num_workers=num_workers,
            batch_size=batch_size if not for_sampling else n_sample_images,
            img_embeddings_url=img_embeddings_url,
            text_embeddings_url=text_embeddings_url,
            index_width=index_width,
            shuffle_num = None,
            extra_keys= ["txt"],
            shuffle_shards = shuffle,
            resample_shards = resample,
            img_preproc=img_preproc,
            handler=wds.handlers.warn_and_continue
        )

    train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
    train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
    val_dataloader = create_dataloader(val_urls, shuffle=False)
    test_dataloader = create_dataloader(test_urls, shuffle=False)
    test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
    return {
        "train": train_dataloader,
        "train_sampling": train_sampling_dataloader,
        "val": val_dataloader,
        "test": test_dataloader,
        "test_sampling": test_sampling_dataloader
    }

def get_dataset_keys(dataloader):
    """
    Return the ``key_map`` of the dataset wrapped by *dataloader*.

    It is sometimes necessary to know which keys the dataloader returns. The dataset is buried inside the
    dataloader, so it has to be dug out first.
    """
    # If the dataloader is actually a WebLoader, we need to extract the real dataloader
    if isinstance(dataloader, wds.WebLoader):
        dataloader = dataloader.pipeline[0]
    return dataloader.dataset.key_map

def get_example_data(dataloader, device, n=5):
    """
    Sample *dataloader* and return up to *n* examples moved to *device*.

    Each example is a tuple ``(image, image_embedding, text_embedding, caption)``. An embedding is None when
    the batch's embedding dict has no ``'img'`` / ``'text'`` entry. Fewer than *n* examples are returned if
    the dataloader is exhausted first.
    """
    images = []
    img_embeddings = []
    text_embeddings = []
    captions = []
    for img, emb, txt in dataloader:
        img_emb, text_emb = emb.get('img'), emb.get('text')
        if img_emb is not None:
            img_emb = img_emb.to(device=device, dtype=torch.float)
            img_embeddings.extend(list(img_emb))
        else:
            # Then we add None img.shape[0] times
            img_embeddings.extend([None]*img.shape[0])
        if text_emb is not None:
            text_emb = text_emb.to(device=device, dtype=torch.float)
            text_embeddings.extend(list(text_emb))
        else:
            # Then we add None img.shape[0] times
            text_embeddings.extend([None]*img.shape[0])
        img = img.to(device=device, dtype=torch.float)
        images.extend(list(img))
        captions.extend(list(txt))
        if len(images) >= n:
            break
    return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
    """
    Generate images from the embeddings in *example_data* (as produced by ``get_example_data``).

    When the examples carry no image embeddings, they are computed with ``trainer.embed_image``. When
    conditioning on text without precomputed text embeddings, the captions are tokenized and passed as text.

    Returns:
        Three sequences: real images, generated images, and captions (each prefixed with *text_prepend*).
        With *match_image_size*, the real images are resized to the generated image size and clamped to (0, 1).
    """
    real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
    sample_params = {}
    if img_embeddings[0] is None:
        # Generate image embeddings from clip
        imgs_tensor = torch.stack(real_images)
        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
        sample_params["image_embed"] = img_embeddings
    else:
        # Then we are using precomputed image embeddings
        img_embeddings = torch.stack(img_embeddings)
        sample_params["image_embed"] = img_embeddings
    if condition_on_text_encodings:
        if text_embeddings[0] is None:
            # Generate text embeddings from text
            tokenized_texts = tokenize(txts, truncate=True)
            sample_params["text"] = tokenized_texts
        else:
            # Then we are using precomputed text embeddings
            text_embeddings = torch.stack(text_embeddings)
            sample_params["text_encodings"] = text_embeddings
    samples = trainer.sample(**sample_params)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
    if match_image_size:
        generated_image_size = generated_images[0].shape[-1]
        real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
    return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
    """
    Generate samples and place each real image next to its generated image using ``torchvision.utils.make_grid``.

    Returns:
        (grid_images, captions)
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

def evaluate_trainer(trainer, dataloader, device, condition_on_text_encodings=False, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Compute image-quality metrics for the decoder on up to *n_evaluation_samples* examples from *dataloader*.

    Each of FID / IS / KID / LPIPS is computed only when its argument is not None. The argument is a dict of
    keyword arguments forwarded to the matching torchmetrics class. Each process computes metrics on its own
    data (``dist_sync_fn`` is a no-op). When there are several processes, the per-process results are then
    averaged with ``accelerator.gather``.

    Returns:
        dict mapping metric names ("FID", "IS_mean", "IS_std", "KID_mean", "KID_std", "LPIPS") to floats.
        Empty if the dataloader yielded no data.
    """
    metrics = {}
    # Prepare the data
    examples = get_example_data(dataloader, device, n_evaluation_samples)
    if len(examples) == 0:
        # NOTE(review): in multi-process runs this early return skips the gather below that other
        # processes may still reach — confirm this cannot stall the other ranks.
        print("No data to evaluate. Check that your dataloader has shards.")
        return metrics
    real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
    int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)

    def null_sync(t, *args, **kwargs):
        # Disable torchmetrics' own cross-process sync; results are combined manually below
        return [t]

    if exists(FID):
        fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
        fid.to(device=device)
        fid.update(int_real_images, real=True)
        fid.update(int_generated_images, real=False)
        metrics["FID"] = fid.compute().item()
    if exists(IS):
        inception = InceptionScore(**IS, dist_sync_fn=null_sync)
        inception.to(device=device)
        # The Inception Score rates the *generated* images. Scoring the real images would give a value
        # that never changes as the decoder trains.
        inception.update(int_generated_images)
        is_mean, is_std = inception.compute()
        metrics["IS_mean"] = is_mean.item()
        metrics["IS_std"] = is_std.item()
    if exists(KID):
        kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
        kernel_inception.to(device=device)
        kernel_inception.update(int_real_images, real=True)
        kernel_inception.update(int_generated_images, real=False)
        kid_mean, kid_std = kernel_inception.compute()
        metrics["KID_mean"] = kid_mean.item()
        metrics["KID_std"] = kid_std.item()
    if exists(LPIPS):
        # Convert from [0, 1] to [-1, 1]
        renorm_real_images = real_images.mul(2).sub(1)
        renorm_generated_images = generated_images.mul(2).sub(1)
        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
        lpips.to(device=device)
        lpips.update(renorm_real_images, renorm_generated_images)
        metrics["LPIPS"] = lpips.compute().item()

    if trainer.accelerator.num_processes > 1:
        # Then we should sync the metrics: pack them in a fixed (sorted) order, gather, and average over processes
        metrics_order = sorted(metrics.keys())
        metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
        for i, metric_name in enumerate(metrics_order):
            metrics_tensor[0, i] = metrics[metric_name]
        metrics_tensor = trainer.accelerator.gather(metrics_tensor)
        metrics_tensor = metrics_tensor.mean(dim=0)
        for i, metric_name in enumerate(metrics_order):
            metrics[metric_name] = metrics_tensor[i].item()
    return metrics

def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
    """
    Save the trainer through ``tracker.save``, together with the progress fields (epoch, sample, next_task,
    validation_losses, samples_seen) that ``recall_trainer`` reads back to resume training.
    """
    tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)

def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
    """
    Load a saved state into *trainer* (strictly, including non-model state) from ``tracker.recall()``.

    Returns:
        (epoch, validation_losses, next_task, sample, samples_seen). Missing keys default to
        0, [], 'train', 0 and 0 respectively.
    """
    trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
    state_dict = tracker.recall()
    trainer.load_state_dict(state_dict, only_model=False, strict=True)
    return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)

def train(
    dataloaders,
    decoder: Decoder,
    accelerator: Accelerator,
    tracker: Tracker,
    inference_device,
    evaluate_config=None,
    epoch_samples = None,  # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples = None,
    epochs = 20,
    n_sample_images = 5,
    save_every_n_samples = 100000,
    unet_training_mask=None,
    condition_on_text_encodings=False,
    **kwargs
):
    """
    Train a decoder on a dataset.

    Each epoch walks through the tasks 'train' -> 'val' -> 'eval' -> 'sample'. The current task is saved
    with every checkpoint, so a recalled run resumes from the task (and sample offset) it stopped at.

    Args:
        dataloaders: dict as returned by ``create_dataloaders``.
        inference_device: Device the trainer and batches are moved to.
        evaluate_config: Optional config whose ``.dict()`` is forwarded to ``evaluate_trainer``; evaluation is skipped when None.
        epoch_samples / validation_samples: Optional sample budgets that end an epoch's train / val phase early.
        n_sample_images: Number of example images used for logged sample grids; 0 or None disables sampling.
        save_every_n_samples: Approximate number of training samples between mid-epoch snapshots.
        unet_training_mask: Per-unet flags selecting which unets are trained/validated; defaults to all.
        condition_on_text_encodings: Whether to pass text embeddings (or tokenized captions) to the decoder.
        **kwargs: Forwarded to ``DecoderTrainer``.
    """
    is_master = accelerator.process_index == 0

    trainer = DecoderTrainer(
        decoder=decoder,
        accelerator=accelerator,
        dataloaders=dataloaders,
        **kwargs
    )

    # Set up starting model and parameters based on a recalled state dict
    start_epoch = 0
    validation_losses = []
    next_task = 'train'
    sample = 0
    samples_seen = 0
    val_sample = 0

    if tracker.can_recall:
        start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
        if next_task == 'train':
            sample = recalled_sample
        if next_task == 'val':
            val_sample = recalled_sample
        accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
        accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
    trainer.to(device=inference_device)

    if not exists(unet_training_mask):
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
    if not any(unet_training_mask):
        raise ValueError("unet_training_mask must enable at least one unet")
    first_training_unet = next(index for index, mask in enumerate(unet_training_mask) if mask)
    # The logging step is the number of optimizer steps taken by the first trained unet (unet numbers are 1-based)
    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
    num_trained_unets = sum(1 for mask in unet_training_mask if mask)
    has_sample_images = exists(n_sample_images) and n_sample_images > 0

    accelerator.print(print_ribbon("Generating Example Data", repeat=40))
    accelerator.print("This can take a while to load the shard lists...")
    if is_master and has_sample_images:
        train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
        accelerator.print("Generated training examples")
        test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
        accelerator.print("Generated testing examples")

    def send_to_device(arr):
        return [x.to(device=inference_device, dtype=torch.float) for x in arr]

    sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
    # Rolling buffer of the last TRAIN_CALC_LOSS_EVERY_ITERS losses per unet; unwritten slots stay 0
    unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
    for epoch in range(start_epoch, epochs):
        accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))

        timer = Timer()
        last_sample = sample
        last_snapshot = sample

        if next_task == 'train':
            for i, (img, emb, txt) in enumerate(trainer.train_loader):
                # We want to count the total number of samples across all processes
                sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
                total_samples = all_samples.sum().item()
                sample += total_samples
                samples_seen += total_samples
                img_emb = emb.get('img')
                has_img_embedding = img_emb is not None
                if has_img_embedding:
                    img_emb, = send_to_device((img_emb,))
                text_emb = emb.get('text')
                has_text_embedding = text_emb is not None
                if has_text_embedding:
                    text_emb, = send_to_device((text_emb,))
                img, = send_to_device((img,))

                trainer.train()
                for unet in range(1, trainer.num_unets+1):
                    # Check if this is a unet we are training
                    if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                        continue

                    forward_params = {}
                    if has_img_embedding:
                        forward_params['image_embed'] = img_emb
                    # Otherwise the forward pass generates the image embedding itself
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb
                        else:
                            # Then we need to pass the text instead
                            tokenized_texts = tokenize(txt, truncate=True)
                            forward_params['text'] = tokenized_texts
                    loss = trainer.forward(img, **forward_params, unet_number=unet)
                    trainer.update(unet_number=unet)
                    unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss

                samples_per_sec = (sample - last_sample) / timer.elapsed()
                timer.reset()
                last_sample = sample

                if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                    # We want to average losses across all processes
                    unet_all_losses = accelerator.gather(unet_losses_tensor)
                    # Zero entries are unwritten buffer slots or untrained unets, so they are excluded from the average
                    mask = unet_all_losses != 0
                    loss_counts = mask.sum(dim=0)
                    # clamp avoids 0/0 = NaN for unets with no recorded loss; those unets are skipped below
                    unet_average_loss = (unet_all_losses * mask).sum(dim=0) / loss_counts.clamp(min=1)
                    loss_map = {
                        f"Unet {index} Training Loss": loss.item()
                        for index, (loss, count) in enumerate(zip(unet_average_loss, loss_counts))
                        if count > 0
                    }

                    # gather decay rate on each UNet
                    ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}

                    log_data = {
                        "Epoch": epoch,
                        "Sample": sample,
                        "Step": i,
                        "Samples per second": samples_per_sec,
                        "Samples Seen": samples_seen,
                        **ema_decay_list,
                        **loss_map
                    }

                    if is_master:
                        tracker.log(log_data, step=step())

                if is_master and last_snapshot + save_every_n_samples < sample:  # This will miss by some amount every time, but it's not a big deal... I hope
                    # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
                    print("Saving snapshot")
                    last_snapshot = sample
                    # We need to know where the model should be saved
                    save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                    if has_sample_images:
                        trainer.eval()
                        train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                if epoch_samples is not None and sample >= epoch_samples:
                    break
            next_task = 'val'
            sample = 0

        all_average_val_losses = None
        if next_task == 'val':
            trainer.eval()
            accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
            last_val_sample = val_sample
            val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
            average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
            timer = Timer()
            accelerator.wait_for_everyone()
            i = 0
            for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
                val_sample_length_tensor[0] = len(img)
                all_samples = accelerator.gather(val_sample_length_tensor)
                total_samples = all_samples.sum().item()
                val_sample += total_samples
                img_emb = emb.get('img')
                has_img_embedding = img_emb is not None
                if has_img_embedding:
                    img_emb, = send_to_device((img_emb,))
                text_emb = emb.get('text')
                has_text_embedding = text_emb is not None
                if has_text_embedding:
                    text_emb, = send_to_device((text_emb,))
                img, = send_to_device((img,))

                for unet in range(1, len(decoder.unets)+1):
                    if not unet_training_mask[unet-1]:  # Unet index is the unet number - 1
                        # No need to evaluate an unchanging unet
                        continue

                    forward_params = {}
                    if has_img_embedding:
                        forward_params['image_embed'] = img_emb.float()
                    # Otherwise the forward pass generates the image embedding itself
                    if condition_on_text_encodings:
                        if has_text_embedding:
                            forward_params['text_encodings'] = text_emb.float()
                        else:
                            # Then we need to pass the text instead
                            tokenized_texts = tokenize(txt, truncate=True)
                            forward_params['text'] = tokenized_texts
                    loss = trainer.forward(img.float(), **forward_params, unet_number=unet)
                    average_val_loss_tensor[0, unet-1] += loss

                if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
                    samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
                    timer.reset()
                    last_val_sample = val_sample
                    accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
                    accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
                    accelerator.print("")

                if validation_samples is not None and val_sample >= validation_samples:
                    break
            print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
            accelerator.wait_for_everyone()
            average_val_loss_tensor /= i+1
            # Gather all the average loss tensors
            all_average_val_losses = accelerator.gather(average_val_loss_tensor)
            if is_master:
                unet_average_val_loss = all_average_val_losses.mean(dim=0)
                val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
                tracker.log(val_loss_map, step=step())
            next_task = 'eval'

        if next_task == 'eval':
            if exists(evaluate_config):
                accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings)
                if is_master:
                    tracker.log(evaluation, step=step())
            next_task = 'sample'
            val_sample = 0

        if next_task == 'sample':
            if is_master:
                # Generate examples and save the model if we are the master
                if has_sample_images:
                    # Generate sample images
                    print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
                    test_images, test_captions = generate_grid_samples(trainer, test_example_data, condition_on_text_encodings, "Test: ")
                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, condition_on_text_encodings, "Train: ")
                    tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
                    tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

                print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
                is_best = False
                if all_average_val_losses is not None:
                    # Averaging over processes leaves one loss per unet; untrained unets stay 0. Reduce this to a
                    # single float (mean over trained unets). Calling .item() on the per-unet vector would fail
                    # for decoders with more than one unet.
                    per_unet_val_loss = all_average_val_losses.mean(dim=0)
                    average_loss = (per_unet_val_loss.sum() / num_trained_unets).item()
                    if len(validation_losses) == 0 or average_loss < min(validation_losses):
                        is_best = True
                    validation_losses.append(average_loss)
                save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
            next_task = 'train'

def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
    """
    Create the tracker described by ``config.tracker``, attach a summary of the accelerator setup, and save
    the config file at *config_path* under the name ``decoder_config.json``.

    *dummy* is forwarded as ``dummy_mode``; ``initialize_training`` sets it on every non-master process.
    """
    tracker_config = config.tracker
    accelerator_config = {
        "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
        "DistributedType": accelerator.distributed_type,
        "NumProcesses": accelerator.num_processes,
        "MixedPrecision": accelerator.mixed_precision
    }
    tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
    tracker.save_config(config_path, config_name='decoder_config.json')
    return tracker

def initialize_training(config: TrainDecoderConfig, config_path):
    """
    Set up the accelerator, per-process data shards, dataloaders, decoder and tracker from *config*, then start ``train``.

    Shards from ``config.data.start_shard`` to ``end_shard`` (inclusive) are split into equal contiguous slices,
    one per process. Any remainder shards are left unused.

    Raises:
        ValueError: for unsupported DeepSpeed setups or when no image / text embedding source is configured.
    """
    # Make sure if we are not loading, distributed models are initialized to the same values
    torch.manual_seed(config.seed)

    # Set up accelerator for configurable distributed training
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    if accelerator.num_processes > 1:
        # We are using distributed training and want to immediately ensure all can connect
        accelerator.print("Waiting for all processes to connect...")
        accelerator.wait_for_everyone()
        accelerator.print("All processes online and connected")

    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
        raise ValueError("DeepSpeed fp16 mode does not support learned variance")

    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
        # This is an invalid configuration until we figure out how to handle this
        raise ValueError("DeepSpeed does not support multi-node distributed training")

    # Set up data
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
    world_size = accelerator.num_processes
    rank = accelerator.process_index
    shards_per_process = len(all_shards) // world_size
    assert shards_per_process > 0, "Not enough shards to split evenly"
    my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]

    dataloaders = create_dataloaders(
        available_shards=my_shards,
        img_preproc = config.data.img_preproc,
        train_prop = config.data.splits.train,
        val_prop = config.data.splits.val,
        test_prop = config.data.splits.test,
        n_sample_images=config.train.n_sample_images,
        **config.data.dict(),
        rank = rank,
        seed = config.seed,
    )

    # Create the decoder model and print basic info
    decoder = config.decoder.create()

    def get_num_parameters(model, only_training=False):
        return sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))

    # Create and initialize the tracker if we are the master
    tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)

    has_img_embeddings = config.data.img_embeddings_url is not None
    has_text_embeddings = config.data.text_embeddings_url is not None
    conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

    has_clip_model = config.decoder.clip is not None
    data_source_string = ""

    if has_img_embeddings:
        data_source_string += "precomputed image embeddings"
    elif has_clip_model:
        data_source_string += "clip image embeddings generation"
    else:
        raise ValueError("No image embeddings source specified")
    if conditioning_on_text:
        if has_text_embeddings:
            data_source_string += " and precomputed text embeddings"
        elif has_clip_model:
            data_source_string += " and clip text encoding generation"
        else:
            raise ValueError("No text embeddings source specified")

    accelerator.print(print_ribbon("Loaded Config", repeat=40))
    accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
    accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
    accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
    for i, unet in enumerate(decoder.unets):
        accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

    train(dataloaders, decoder, accelerator,
        tracker=tracker,
        inference_device=accelerator.device,
        evaluate_config=config.evaluate,
        condition_on_text_encodings=conditioning_on_text,
        **config.train.dict(),
    )

# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
    """Load the training config from *config_file* (JSON) and start decoder training."""
    config_file_path = Path(config_file)
    config = TrainDecoderConfig.from_json_path(str(config_file_path))
    initialize_training(config, config_path=config_file_path)

if __name__ == "__main__":
    main()